Support gqa smooth #279

Merged
merged 10 commits on Dec 30, 2024

Changes from all commits
156 changes: 123 additions & 33 deletions llmc/compression/quantization/awq.py
@@ -23,6 +23,16 @@ def __init__(self, model, quant_config, input, padding_mask, config):
        self.trans = special_config.get('trans', True)
        self.trans_version = special_config.get('trans_version', 'v2')
        self.save_scale = special_config.get('save_scale', False)
+       self.awq_bs = special_config.get('awq_bs', None)

+   @torch.no_grad()
+   def scaling_weight(self, w, scales, is_gqa):
+       if is_gqa:
+           scales_tmp = self.repeat_gqa_scales(scales)
+       else:
+           scales_tmp = scales
+       w_tmp = w.mul_(scales_tmp.view(1, -1))
+       return w_tmp

    @torch.no_grad()
    def get_weight_scale(self, layers_dict):
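Note: the GQA branch above relies on repeat_gqa_scales, added in base_blockwise_quantization.py further down. A minimal shape sketch of what it does, assuming Llama-style sizes (32 query heads, 8 KV heads, head_dim 128; the numbers are illustrative, not from this PR):

    import torch

    num_heads, num_key_value_heads, head_dim = 32, 8, 128
    groups = num_heads // num_key_value_heads               # 4 query heads share each KV head

    scales = torch.rand(num_key_value_heads * head_dim)     # one scale per v_proj output channel
    scales = scales.view(1, num_key_value_heads, head_dim)
    scales = torch.repeat_interleave(scales, dim=1, repeats=groups)
    assert scales.reshape(1, -1).shape[1] == num_heads * head_dim  # matches o_proj.in_features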
@@ -49,20 +59,82 @@ def get_weight_scale(self, layers_dict):
            torch.cuda.empty_cache()
        return scale

    @torch.no_grad()
    def get_act_scale(self, x):
-       return x.abs().view(-1, x.shape[-1]).mean(0)
+       batch_means = []
+       b_num = x.shape[0] // self._bs
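+       # note: assumes x.shape[0] is a multiple of self._bs; any samples past
+       # b_num * self._bs are silently skipped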
+       for num in range(b_num):
+           batch_x = x[num * self._bs:(num + 1) * self._bs]
+           batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+           batch_means.append(batch_mean)
+       final_mean = sum(batch_means) / len(batch_means)
+       return final_mean

+   @torch.no_grad()
+   def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
+       if is_gqa:
+           x_tmp = prev_op(x)
+           w_tmp = self.get_weight_scale({'prev_op': prev_op})
+       else:
+           x_tmp = x
+           w_tmp = w_max
+
+       x_tmp = self.get_act_scale(x_tmp)
+
+       if self.trans_version == 'v1':
+           scales = (
+               (x_tmp.pow(ratio) / w_tmp.pow(1 - ratio))
+               .clamp(min=1e-4)
+               .view(-1)
+           )
+       elif self.trans_version == 'v2':
+           scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1)
+
+       scales = scales / (scales.max() * scales.min()).sqrt()
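+       # normalize so that scales.max() * scales.min() == 1, i.e. the scales pivot around 1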
+       return scales

+   def inspect_module_forward(self, x, inspect_module, kwargs):
+       outs = []
+       b_num = x.shape[0] // self._bs
+       for num in range(b_num):
+           _x = x[num * self._bs:(num + 1) * self._bs]
+           out = inspect_module(_x, **kwargs)
+           if isinstance(out, tuple):
+               out = out[0]
+           outs.append(out)
+       return torch.cat(outs, dim=0)

    @torch.no_grad()
    def get_original_out(self, x, inspect_module, subset_kwargs):
-       with torch.no_grad():
-           org_out = inspect_module(x, **subset_kwargs)
-           if isinstance(org_out, tuple):
-               org_out = org_out[0]
+       org_out = self.inspect_module_forward(x, inspect_module, subset_kwargs)
        return org_out

+   def calculate_loss(self, org_out, out):
+       total_loss = 0.0
+       b_num = org_out.shape[0] // self._bs
+       for num in range(b_num):
+           _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+           _out = out[num * self._bs:(num + 1) * self._bs]
+           single_loss = (_org_out - _out).float().pow(2).mean().item()
+           total_loss += single_loss
+       return total_loss / b_num

    @torch.no_grad()
-   def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
+   def search_scale_subset(
+       self,
+       prev_op,
+       layers_dict,
+       input,
+       inspect_module,
+       is_gqa,
+       subset_kwargs
+   ):

+       if self.awq_bs is None:
+           self._bs = input[0].shape[0]
+       else:
+           self._bs = self.awq_bs
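+       # awq_bs caps the mini-batch size used by the batched helpers above;
+       # by default the full calibration batch is processed at once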

        w_max = self.get_weight_scale(layers_dict)
        # grid search for ratio
        best_error = float('inf')
@@ -89,18 +161,10 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                x_max = self.get_act_scale(x)

                ratio = n * 1 / n_grid
-               if self.trans_version == 'v1':
-                   scales = (
-                       (x_max.pow(ratio) / w_max.pow(1 - ratio))
-                       .clamp(min=1e-4)
-                       .view(-1)
-                   )
-               elif self.trans_version == 'v2':
-                   scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
-               scales = scales / (scales.max() * scales.min()).sqrt()
+               scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
                for layer_name in layers_dict:
                    fc = layers_dict[layer_name]
-                   fc.weight.mul_(scales.view(1, -1))
+                   fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)

                    fc.weight.data = get_wquantizer(
                        self.block_idx,
@@ -110,31 +174,39 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                        self.wquantizer,
                    ).fake_quant_weight_dynamic(fc.weight.data)

-               x_tmp = x / scales.view(1, -1)
+               del x_max
+               gc.collect()
+               torch.cuda.empty_cache()
+
+               x_tmp = self.scaling_input(x, scales, is_gqa)

                if not check_w_only(
                    self.block_idx,
                    list(layers_dict.keys())[0],
                    self.mix_bits_map,
                    self.quantizer_mix_bits,
                    self.w_only,
                ):
-                   x_tmp = get_aquantizer(
-                       self.block_idx,
-                       list(layers_dict.keys())[0],
-                       self.mix_bits_map,
-                       self.quantizer_mix_bits,
-                       self.aquantizer,
-                   ).fake_quant_act_dynamic(x_tmp)
-               out = inspect_module(x_tmp, **kwargs)
-
-               if isinstance(out, tuple):
-                   out = out[0]
+                   outs = []
+                   # loop variable renamed to j: reusing i here would shadow the
+                   # outer input index that the padding_mask lookup below still needs
+                   for j in range(x_tmp.shape[0]):
+                       _x = x_tmp[j]
+                       _x = get_aquantizer(
+                           self.block_idx,
+                           list(layers_dict.keys())[0],
+                           self.mix_bits_map,
+                           self.quantizer_mix_bits,
+                           self.aquantizer,
+                       ).fake_quant_act_dynamic(_x)
+                       outs.append(_x)
+                   x_tmp = torch.stack(outs)
+
+               out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

                if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
                    org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device)  # noqa
                    out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

-               loss = (org_out - out).float().pow(2).mean().item()
+               loss = self.calculate_loss(org_out, out)

                if len(input) == 1:
                    n_samples = x.shape[0]
@@ -149,6 +221,11 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                best_error = loss_mean
                best_scales = scales_mean

+           del org_out
+           del out
+           gc.collect()
+           torch.cuda.empty_cache()

        # Synchronize across ranks
        best_error_tensor = torch.tensor([best_error], device='cuda')
        dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
@@ -248,15 +325,28 @@ def subset_transform(
            and prev_op[0].out_features != layers[0].in_features * 2
            and prev_op[0].out_features != layers[0].in_features
        ):
-           logger.info('Cannot apply scale. Do not transform this subset.')
-           return
+           if self.has_gqa:
+               is_gqa = True
+               input_keys = list(input_feat.keys())
+               input_name = input_keys[input_keys.index(input_name) - 1]
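+               # step back one entry in input_feat: presumably this picks up the
+               # input of prev_op (e.g. v_proj), which get_scales then pushes
+               # through prev_op(x) in its GQA branch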
+           else:
+               logger.info('Cannot apply scale. Do not transform this subset.')
+               return
+       else:
+           is_gqa = False

        scale = self.search_scale_subset(
-           layers_dict, input_feat[input_name], inspect_module, subset_kwargs
+           prev_op[0],
+           layers_dict,
+           input_feat[input_name],
+           inspect_module,
+           is_gqa,
+           subset_kwargs
        )

        self.apply_scale(scale, prev_op, layers)
-       self.update_input_feat(scale, input_feat, layers_dict)
+       self.update_input_feat(scale, input_feat, layers_dict, is_gqa)

        if self.save_scale:
            for n in layers_dict:
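For reference, a standalone sketch of the ratio-sweep scale rule that get_scales implements (the same formulas as in the diff above, assuming 1-D x_max and w_max magnitude vectors):

    import torch

    def grid_scales(x_max, w_max, ratio, trans_version='v2'):
        # v1 balances activation and weight magnitudes; v2 uses activations only
        if trans_version == 'v1':
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
        else:
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
        # normalize so scales.max() * scales.min() == 1
        return scales / (scales.max() * scales.min()).sqrt()

    x_max, w_max = torch.rand(16) + 0.1, torch.rand(16) + 0.1
    for n in range(20):                     # mirrors the n_grid = 20 sweep
        s = grid_scales(x_max, w_max, n / 20)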
56 changes: 48 additions & 8 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -283,15 +283,26 @@ def set_quant_config(self):
        assert self.config['model']['type'] in ['Opt', 'Llama']

-       self.hidden_size = self.model.model_config.hidden_size
-       if self.online_rotate:
-           self.num_heads = self.model.model_config.num_attention_heads
-           self.head_dim = self.hidden_size // self.num_heads
-           self.intermediate_size = self.model.model_config.intermediate_size
        self.fp32_had = special_config.get('fp32_had', False)

+       self.set_model_config()
        self.quant_objects = self.quant_config.get('quant_objects', ['language'])
        logger.info(f'self.quant_objects : {self.quant_objects}')

+   def set_model_config(self):
+       self.hidden_size = self.model.model_config.hidden_size
+       self.num_heads = self.model.model_config.num_attention_heads
+       self.head_dim = self.hidden_size // self.num_heads
+       if hasattr(self.model.model_config, 'intermediate_size'):
+           self.intermediate_size = self.model.model_config.intermediate_size
+       if hasattr(self.model.model_config, 'num_key_value_heads'):
+           self.num_key_value_heads = self.model.model_config.num_key_value_heads
+           self.num_key_value_groups = self.num_heads // self.num_key_value_heads
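+           # e.g. 32 query heads over 8 KV heads gives num_key_value_groups = 4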
+           if self.num_key_value_groups > 1:
+               self.has_gqa = True
+           else:
+               self.has_gqa = False
+       else:
+           self.has_gqa = False

    def replace_rotate_linears(self, block):
        for n, m in block.named_modules():
            if isinstance(m, nn.Linear) and (
@@ -581,6 +592,12 @@ def register_act_qparams(self, layers_dict, act_tensors):
            layer.register_buffer(f'buf_act_qmin_{i}', qmin.cuda())
            layer.register_buffer(f'buf_act_qmax_{i}', qmax.cuda())

+   @torch.no_grad()
+   def repeat_gqa_scales(self, scales):
+       scales = scales.view(1, self.num_key_value_heads, self.head_dim)
+       scales = torch.repeat_interleave(scales, dim=1, repeats=self.num_key_value_groups)
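+       # (1, num_key_value_heads, head_dim) -> (1, num_heads, head_dim)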
+       return scales

    @torch.no_grad()
    def apply_scale(self, scales, prev_op, layers):
        assert (
@@ -652,6 +669,14 @@ def scale_fc_fc(self, fc1, fc2, scales):
                fc1.bias.div_(scales.view(-1))

            fc1.weight.div_(scales.view(-1, 1))
+       elif self.has_gqa:
+           if hasattr(fc1, 'bias') and fc1.bias is not None:
+               fc1.bias.div_(scales.view(-1))
+           fc1.weight.div_(scales.view(-1, 1))
+
+           if fc1.out_features != fc2.in_features:
+               logger.info('GQA scale this fc-fc.')
+               scales = self.repeat_gqa_scales(scales)
        else:
            logger.error(f'fc1.out_features: {fc1.out_features}')
            logger.error(f'fc2.in_features: {fc2.in_features}')
@@ -795,11 +820,26 @@ def bake_mean_into_fc(self, fc):
        fc.bias.data = fc.bias.data.to(fc_dtype)

    @torch.no_grad()
-   def update_input_feat(self, scale, input_feat, layers_dict):
+   def scaling_input(self, x, scales, is_gqa):
+       if is_gqa:
+           scales_tmp = self.repeat_gqa_scales(scales)
+       else:
+           scales_tmp = scales
+
+       x_tmp = torch.empty_like(x)
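+       # fill a preallocated buffer sample-by-sample to keep peak memory low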
+       for i, batch in enumerate(x):
+           batch_scale = scales_tmp.view(1, -1)
+           x_tmp[i] = batch / batch_scale
+
+       return x_tmp

+   @torch.no_grad()
+   def update_input_feat(self, scale, input_feat, layers_dict, is_gqa):
        for layer_name in layers_dict:
            for i in range(len(input_feat[layer_name])):
                inp = input_feat[layer_name][i]
-               inp.div_(scale.view(1, -1).to(inp.device))
+               scale = scale.to(inp.device)
+               # assign the result back into input_feat; a bare `inp = ...`
+               # rebinds the local name and would discard the scaled copy
+               input_feat[layer_name][i] = self.scaling_input(inp, scale, is_gqa)

    @torch.no_grad()
    def set_non_linear_mode(self, quant_format, module, mode):
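A minimal equivalence check for the new scaling_input helper on the non-GQA path (assumed sizes; the per-sample loop trades one large broadcast divide for writes into a preallocated buffer):

    import torch

    x = torch.randn(4, 16, 1024)            # (samples, seq_len, hidden)
    scales = torch.rand(1024) + 0.1

    x_tmp = torch.empty_like(x)
    for i, batch in enumerate(x):            # mirrors scaling_input
        x_tmp[i] = batch / scales.view(1, -1)

    assert torch.allclose(x_tmp, x / scales.view(1, -1))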