
Commit e320149

gushiqiao and gushiqiao authored
Support gqa smooth (#279)
* Support gqa smooth * Support gqa smooth * Support gqa smooth * FIx bug * fix bug * fix bug * fix bug * fix bug * fix bug * fix bug --------- Co-authored-by: gushiqiao <[email protected]>
1 parent 75f8d1d commit e320149

File tree

4 files changed: +488 −133 lines changed


llmc/compression/quantization/awq.py

Lines changed: 123 additions & 33 deletions
@@ -23,6 +23,16 @@ def __init__(self, model, quant_config, input, padding_mask, config):
         self.trans = special_config.get('trans', True)
         self.trans_version = special_config.get('trans_version', 'v2')
         self.save_scale = special_config.get('save_scale', False)
+        self.awq_bs = special_config.get('awq_bs', None)
+
+    @torch.no_grad()
+    def scaling_weight(self, w, scales, is_gqa):
+        if is_gqa:
+            scales_tmp = self.repeat_gqa_scales(scales)
+        else:
+            scales_tmp = scales
+        w_tmp = w.mul_(scales_tmp.view(1, -1))
+        return w_tmp
 
     @torch.no_grad()
     def get_weight_scale(self, layers_dict):
@@ -49,20 +59,82 @@ def get_weight_scale(self, layers_dict):
         torch.cuda.empty_cache()
         return scale
 
-    @torch.no_grad()
     def get_act_scale(self, x):
-        return x.abs().view(-1, x.shape[-1]).mean(0)
+        batch_means = []
+        b_num = x.shape[0] // self._bs
+        for num in range(b_num):
+            batch_x = x[num * self._bs:(num + 1) * self._bs]
+            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+            batch_means.append(batch_mean)
+        final_mean = sum(batch_means) / len(batch_means)
+        return final_mean
+
+    @torch.no_grad()
+    def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
+        if is_gqa:
+            x_tmp = prev_op(x)
+            w_tmp = self.get_weight_scale({'prev_op': prev_op})
+        else:
+            x_tmp = x
+            w_tmp = w_max
+
+        x_tmp = self.get_act_scale(x_tmp)
+
+        if self.trans_version == 'v1':
+            scales = (
+                (x_tmp.pow(ratio) / w_tmp.pow(1 - ratio))
+                .clamp(min=1e-4)
+                .view(-1)
+            )
+        elif self.trans_version == 'v2':
+            scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1)
+
+        scales = scales / (scales.max() * scales.min()).sqrt()
+        return scales
+
+    def inspect_module_forward(self, x, inspect_module, kwargs):
+        outs = []
+        b_num = x.shape[0] // self._bs
+        for num in range(b_num):
+            _x = x[num * self._bs:(num + 1) * self._bs]
+            out = inspect_module(_x, **kwargs)
+            if isinstance(out, tuple):
+                out = out[0]
+            outs.append(out)
+        return torch.cat(outs, dim=0)
 
     @torch.no_grad()
     def get_original_out(self, x, inspect_module, subset_kwargs):
         with torch.no_grad():
-            org_out = inspect_module(x, **subset_kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
+            org_out = self.inspect_module_forward(x, inspect_module, subset_kwargs)
         return org_out
 
+    def calculate_loss(self, org_out, out):
+        total_loss = 0.0
+        b_num = org_out.shape[0] // self._bs
+        for num in range(b_num):
+            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+            _out = out[num * self._bs:(num + 1) * self._bs]
+            single_loss = (_org_out - _out).float().pow(2).mean().item()
+            total_loss += single_loss
+        return total_loss / b_num
+
     @torch.no_grad()
-    def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
+    def search_scale_subset(
+        self,
+        prev_op,
+        layers_dict,
+        input,
+        inspect_module,
+        is_gqa,
+        subset_kwargs
+    ):
+
+        if self.awq_bs is None:
+            self._bs = input[0].shape[0]
+        else:
+            self._bs = self.awq_bs
+
         w_max = self.get_weight_scale(layers_dict)
         # grid search for ratio
         best_error = float('inf')
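The rewritten get_act_scale, inspect_module_forward and calculate_loss above process the calibration tensor in chunks of self._bs (taken from the new awq_bs option, or the full batch when it is unset), so activation statistics and reconstruction loss no longer need the whole batch in memory at once. The snippet below is a minimal standalone sketch of that chunking idea, not the llmc API; the name chunked_act_scale and the shapes are illustrative, and it assumes the batch dimension divides evenly by the chunk size so the mean of per-chunk means equals the full-tensor mean.

import torch

# Illustrative sketch (not part of llmc): per-channel mean of |x| computed in
# chunks, mirroring the batched get_act_scale in the hunk above.
def chunked_act_scale(x: torch.Tensor, bs: int) -> torch.Tensor:
    means = []
    for start in range(0, x.shape[0], bs):
        chunk = x[start:start + bs]
        means.append(chunk.abs().view(-1, chunk.shape[-1]).mean(0))
    return torch.stack(means).mean(0)

x = torch.randn(8, 16, 64)                 # (batch, seq_len, hidden)
reference = x.abs().view(-1, 64).mean(0)   # unchunked statistic
assert torch.allclose(chunked_act_scale(x, bs=2), reference, atol=1e-5)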
@@ -89,18 +161,10 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                 x_max = self.get_act_scale(x)
 
                 ratio = n * 1 / n_grid
-                if self.trans_version == 'v1':
-                    scales = (
-                        (x_max.pow(ratio) / w_max.pow(1 - ratio))
-                        .clamp(min=1e-4)
-                        .view(-1)
-                    )
-                elif self.trans_version == 'v2':
-                    scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
-                scales = scales / (scales.max() * scales.min()).sqrt()
+                scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
                 for layer_name in layers_dict:
                     fc = layers_dict[layer_name]
-                    fc.weight.mul_(scales.view(1, -1))
+                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)
 
                     fc.weight.data = get_wquantizer(
                         self.block_idx,
@@ -110,31 +174,39 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                         self.wquantizer,
                     ).fake_quant_weight_dynamic(fc.weight.data)
 
-                x_tmp = x / scales.view(1, -1)
+                del x_max
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                x_tmp = self.scaling_input(x, scales, is_gqa)
+
                 if not check_w_only(
                     self.block_idx,
                     list(layers_dict.keys())[0],
                     self.mix_bits_map,
                     self.quantizer_mix_bits,
                     self.w_only,
                 ):
-                    x_tmp = get_aquantizer(
-                        self.block_idx,
-                        list(layers_dict.keys())[0],
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.aquantizer,
-                    ).fake_quant_act_dynamic(x_tmp)
-                out = inspect_module(x_tmp, **kwargs)
-
-                if isinstance(out, tuple):
-                    out = out[0]
+                    outs = []
+                    for i in range(x_tmp.shape[0]):
+                        _x = x_tmp[i]
+                        _x = get_aquantizer(
+                            self.block_idx,
+                            list(layers_dict.keys())[0],
+                            self.mix_bits_map,
+                            self.quantizer_mix_bits,
+                            self.aquantizer,
+                        ).fake_quant_act_dynamic(_x)
+                        outs.append(_x)
+                    x_tmp = torch.stack(outs)
+
+                out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)
 
                 if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
                     org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device) # noqa
                     out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)
 
-                loss = (org_out - out).float().pow(2).mean().item()
+                loss = self.calculate_loss(org_out, out)
 
                 if len(input) == 1:
                     n_samples = x.shape[0]
@@ -149,6 +221,11 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                 best_error = loss_mean
                 best_scales = scales_mean
 
+            del org_out
+            del out
+            gc.collect()
+            torch.cuda.empty_cache()
+
         # Synchronize across ranks
         best_error_tensor = torch.tensor([best_error], device='cuda')
         dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
@@ -248,15 +325,28 @@ def subset_transform(
                 and prev_op[0].out_features != layers[0].in_features * 2
                 and prev_op[0].out_features != layers[0].in_features
             ):
-                logger.info('Cannot apply scale. Do not transform this subset.')
-                return
+
+                if self.has_gqa:
+                    is_gqa = True
+                    input_keys = list(input_feat.keys())
+                    input_name = input_keys[input_keys.index(input_name) - 1]
+                else:
+                    logger.info('Cannot apply scale. Do not transform this subset.')
+                    return
+            else:
+                is_gqa = False
 
             scale = self.search_scale_subset(
-                layers_dict, input_feat[input_name], inspect_module, subset_kwargs
+                prev_op[0],
+                layers_dict,
+                input_feat[input_name],
+                inspect_module,
+                is_gqa,
+                subset_kwargs
             )
 
             self.apply_scale(scale, prev_op, layers)
-            self.update_input_feat(scale, input_feat, layers_dict)
+            self.update_input_feat(scale, input_feat, layers_dict, is_gqa)
 
             if self.save_scale:
                 for n in layers_dict:
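For context, the search these hunks extend is AWQ's grid search over a smoothing ratio: per-channel scales are built from activation and weight statistics (the v1/v2 formulas now living in get_scales), the weights are scaled and fake-quantized, and the ratio with the lowest reconstruction error wins. A hedged sketch of that loop follows; every name in it is illustrative rather than the llmc API, and the quantizer and forward function are caller-supplied stand-ins.

import torch

# Illustrative grid search over the smoothing ratio (v1/v2 scale formulas as in
# get_scales above); not the llmc implementation.
def search_ratio(x_max, w_max, quantize_fn, forward_fn, w, x,
                 n_grid=20, version='v2'):
    best_loss, best_scales = float('inf'), None
    org_out = forward_fn(w, x)
    for n in range(n_grid):
        ratio = n / n_grid
        if version == 'v1':
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
        else:
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
        scales = scales / (scales.max() * scales.min()).sqrt()
        w_q = quantize_fn(w * scales.view(1, -1))       # scale weights, then fake-quantize
        out = forward_fn(w_q, x / scales.view(1, -1))   # compensate on the activation side
        loss = (org_out - out).float().pow(2).mean().item()
        if loss < best_loss:
            best_loss, best_scales = loss, scales
    return best_scales

# Toy usage: a crude rounding "quantizer" and a plain linear forward.
w, x = torch.randn(32, 64), torch.randn(128, 64)
scales = search_ratio(x.abs().mean(0), w.abs().amax(0),
                      lambda w_: torch.round(w_ * 8) / 8,
                      lambda w_, x_: x_ @ w_.t(), w, x)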

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 48 additions & 8 deletions
@@ -283,15 +283,26 @@ def set_quant_config(self):
             assert self.config['model']['type'] in ['Opt', 'Llama']
 
         self.hidden_size = self.model.model_config.hidden_size
-        if self.online_rotate:
-            self.num_heads = self.model.model_config.num_attention_heads
-            self.head_dim = self.hidden_size // self.num_heads
-            self.intermediate_size = self.model.model_config.intermediate_size
-            self.fp32_had = special_config.get('fp32_had', False)
-
+        self.set_model_config()
         self.quant_objects = self.quant_config.get('quant_objects', ['language'])
         logger.info(f'self.quant_objects : {self.quant_objects}')
 
+    def set_model_config(self):
+        self.hidden_size = self.model.model_config.hidden_size
+        self.num_heads = self.model.model_config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        if hasattr(self.model.model_config, 'intermediate_size'):
+            self.intermediate_size = self.model.model_config.intermediate_size
+        if hasattr(self.model.model_config, 'num_key_value_heads'):
+            self.num_key_value_heads = self.model.model_config.num_key_value_heads
+            self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+            if self.num_key_value_groups > 1:
+                self.has_gqa = True
+            else:
+                self.has_gqa = False
+        else:
+            self.has_gqa = False
+
     def replace_rotate_linears(self, block):
         for n, m in block.named_modules():
             if isinstance(m, nn.Linear) and (
@@ -581,6 +592,12 @@ def register_act_qparams(self, layers_dict, act_tensors):
                 layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
                 layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())
 
+    @torch.no_grad()
+    def repeat_gqa_scales(self, scales):
+        scales = scales.view(1, self.num_key_value_heads, self.head_dim)
+        scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups)
+        return scales
+
     @torch.no_grad()
     def apply_scale(self, scales, prev_op, layers):
         assert (
@@ -652,6 +669,14 @@ def scale_fc_fc(self, fc1, fc2, scales):
                 fc1.bias.div_(scales.view(-1))
 
             fc1.weight.div_(scales.view(-1, 1))
+        elif self.has_gqa:
+            if hasattr(fc1, 'bias') and fc1.bias is not None:
+                fc1.bias.div_(scales.view(-1))
+            fc1.weight.div_(scales.view(-1, 1))
+
+            if fc1.out_features != fc2.in_features:
+                logger.info('GQA scale this fc-fc.')
+                scales = self.repeat_gqa_scales(scales)
         else:
             logger.error(f'fc1.out_features: {fc1.out_features}')
             logger.error(f'fc2.in_features: {fc2.in_features}')
@@ -795,11 +820,26 @@ def bake_mean_into_fc(self, fc):
             fc.bias.data = fc.bias.data.to(fc_dtype)
 
     @torch.no_grad()
-    def update_input_feat(self, scale, input_feat, layers_dict):
+    def scaling_input(self, x, scales, is_gqa):
+        if is_gqa:
+            scales_tmp = self.repeat_gqa_scales(scales)
+        else:
+            scales_tmp = scales
+
+        x_tmp = torch.empty_like(x)
+        for i, batch in enumerate(x):
+            batch_scale = scales_tmp.view(1, -1)
+            x_tmp[i] = batch / batch_scale
+
+        return x_tmp
+
+    @torch.no_grad()
+    def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
         for layer_name in layers_dict:
            for i in range(len(input_feat[layer_name])):
                 inp = input_feat[layer_name][i]
-                inp.div_(scale.view(1, -1).to(inp.device))
+                scale = scale.to(inp.device)
+                inp = self.scaling_input(inp, scale, is_gqa)
 
     @torch.no_grad()
     def set_non_linear_mode(self, quant_format, module, mode):
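The key GQA piece in this file is repeat_gqa_scales: in grouped-query attention the k/v projections have num_key_value_heads * head_dim output channels, while the layers the scales are later applied to expect num_attention_heads * head_dim input channels, so each KV head's scales must be repeated once per query-head group. Below is a minimal sketch of that expansion with made-up head counts; it is illustrative only, not a call into llmc.

import torch

# Illustrative GQA scale expansion (shapes are made up): scales computed per
# KV-head channel are repeated across each query-head group, matching what
# repeat_gqa_scales does before the scales are flattened with .view(1, -1).
num_heads, num_kv_heads, head_dim = 32, 8, 128
groups = num_heads // num_kv_heads            # 4 query heads share one KV head

scales = torch.rand(num_kv_heads * head_dim)  # one scale per KV channel
expanded = scales.view(1, num_kv_heads, head_dim)
expanded = torch.repeat_interleave(expanded, repeats=groups, dim=1)
expanded = expanded.reshape(-1)               # now num_heads * head_dim channels

assert expanded.numel() == num_heads * head_dim
# The query heads in a group all reuse their KV head's scales.
assert torch.equal(expanded[:head_dim], expanded[head_dim:2 * head_dim])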

0 commit comments

Comments
 (0)