diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py
index 9305674a..af3b7bba 100644
--- a/llmc/compression/quantization/quant.py
+++ b/llmc/compression/quantization/quant.py
@@ -32,23 +32,36 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         self.round_zp = self.kwargs.get('round_zp', True)
         self.sigmoid = torch.nn.Sigmoid()
 
-    def get_tensor_range(self, tensor, args={}):
-        if self.calib_algo == 'minmax':
-            return self.get_minmax_range(tensor)
-        elif self.calib_algo == 'mse':
-            return self.get_mse_range(tensor)
-        elif self.calib_algo == 'learnable':
-            return self.get_learnable_range(tensor, **args)
+        # mse config
+        self.maxshrink = self.kwargs.get('maxshrink', 0.8)
+        self.mse_grid = self.kwargs.get('mse_grid', 100)
+
+        # hist config
+        self.bins = self.kwargs.get('bins', 2048)
+        self.hist_threshold = self.kwargs.get('hist_threshold', 1)
+
+        # hqq config
+        self.lp_norm = self.kwargs.get('lp_norm', 0.7)
+        self.beta = self.kwargs.get('beta', 10)
+        self.kappa = self.kwargs.get('kappa', 1.01)
+        self.iters = self.kwargs.get('iters', 20)
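+        # shrink_op is HQQ's proximal (soft-thresholding) operator: for
+        # lp_norm == 1 it is the classic l1 shrinkage sign(x) * relu(|x| - 1/beta);
+        # otherwise the generalized l_p variant. beta is annealed by kappa in
+        # optimize_weights_proximal below.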
+        if self.lp_norm == 1:
+            self.shrink_op = lambda x, beta: torch.sign(x) * torch.nn.functional.relu(
+                torch.abs(x) - 1.0 / beta
+            )
         else:
-            raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
+            self.shrink_op = lambda x, beta, p=self.lp_norm: torch.sign(
+                x
+            ) * torch.nn.functional.relu(
+                torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), p - 1)
+            )
 
-    def get_running_tensor_range(self, act_tensors, alpha, args):
+    def reshape_batch_tensors(self, act_tensors):
         assert len(act_tensors) > 0, (
             'Calibration data is insufficient. Please provide more data to ensure '
             'all experts in the MOE receive an adequate number of tokens.'
         )
-        runing_min_vals, runing_max_vals = [], []
         if isinstance(act_tensors[0], tuple):
             # Handle multiple inputs by stacking tensors.
             unzipped_inputs = zip(*act_tensors)
@@ -60,27 +73,45 @@ def get_running_tensor_range(self, act_tensors, alpha, args):
             act_tensors[0] = tensor_list
         else:
             act_tensors = [act_tensors]
+        return act_tensors
 
-        for tensors in act_tensors:
-            runing_min_val, runing_max_val = None, None
-            for tensor in tensors:
-                tensor = self.reshape_tensor(tensor)
-                tensor_range = self.get_tensor_range(tensor, args)
-                min_val, max_val = tensor_range[0], tensor_range[1]
-
-                if runing_min_val is None or runing_max_val is None:
-                    runing_min_val = min_val
-                    runing_max_val = max_val
-                else:
-                    runing_min_val = runing_min_val + alpha * (
-                        min_val - runing_min_val
-                    )
-                    runing_max_val = runing_max_val + alpha * (
-                        max_val - runing_max_val
-                    )
-            runing_min_vals.append(runing_min_val)
-            runing_max_vals.append(runing_max_val)
+    def get_tensor_range(self, tensor, args={}):
+        if self.calib_algo == 'minmax':
+            return self.get_minmax_range(tensor)
+        elif self.calib_algo == 'mse':
+            return self.get_mse_range(tensor)
+        elif self.calib_algo == 'learnable':
+            return self.get_learnable_range(tensor, **args)
+        else:
+            raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
+
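+    # Histogram calibration: normalize the |x| histogram, accumulate it, and
+    # clip at the first bin where the cumulative mass reaches hist_threshold;
+    # fall back to the plain min/max stats if the threshold is never reached.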
+    def get_hist_range(self, stats_min_max, act_stats_hist):
+        clip_val = {}
+        for input_idx, hist in act_stats_hist.items():
+            hist = hist.float() / hist.sum()
+            data_max = max(
+                -torch.min(stats_min_max[input_idx]['min']),
+                torch.max(stats_min_max[input_idx]['max']),
+            )
+            accum = 0
+            for i in range(len(hist)):
+                accum += hist[i]
+                if accum >= self.hist_threshold:
+                    clip_value = (i + 0.5) * (data_max / self.bins)
+                    clip_val[input_idx] = [
+                        max(-clip_value, torch.min(stats_min_max[input_idx]['min'])),
+                        min(clip_value, torch.max(stats_min_max[input_idx]['max'])),
+                    ]
+                    break
+            if input_idx not in clip_val:
+                clip_val[input_idx] = [
+                    torch.min(stats_min_max[input_idx]['min']),
+                    torch.max(stats_min_max[input_idx]['max']),
+                ]
+        runing_min_vals, runing_max_vals = [], []
+        for input_idx, tensor_range in clip_val.items():
+            runing_min_vals.append(tensor_range[0])
+            runing_max_vals.append(tensor_range[1])
         return runing_min_vals, runing_max_vals
 
     def get_minmax_range(self, tensor):
@@ -93,7 +124,7 @@
 
         return (min_val, max_val)
 
-    def get_mse_range(self, tensor, grid=100, norm=2.4, maxshrink=0.8, bs=256):
+    def get_mse_range(self, tensor, norm=2.4, bs=256):
         assert (
             self.mse_b_num >= 1 and tensor.shape[0] % self.mse_b_num == 0
@@ -115,8 +146,8 @@
 
             best_min_val, best_max_val = _min_val, _max_val
 
-            for i in range(int(maxshrink * grid)):
-                p = 1 - i / grid
+            for i in range(int(self.maxshrink * self.mse_grid)):
+                p = 1 - i / self.mse_grid
 
                 xmin = p * _min_val
                 xmax = p * _max_val
@@ -169,6 +200,148 @@ def get_learnable_range(self, tensor, lowbound_factor=None, upbound_factor=None)
 
         return (min_val, max_val)
 
+    def get_minmax_stats(self, act_tensors):
+        stats_min_max = {}
+        for input_idx, tensors in enumerate(act_tensors):
+            for tensor in tensors:
+                tensor = self.reshape_tensor(tensor)
+                tensor_range = self.get_minmax_range(tensor)
+                min_val, max_val = tensor_range[0], tensor_range[1]
+
+                if input_idx not in stats_min_max:
+                    stats_min_max[input_idx] = {}
+                    stats_min_max[input_idx]['min'] = torch.tensor(
+                        [min_val], dtype=torch.float32
+                    )
+                    stats_min_max[input_idx]['max'] = torch.tensor(
+                        [max_val], dtype=torch.float32
+                    )
+                else:
+                    stats_min_max[input_idx]['min'] = torch.cat(
+                        [
+                            stats_min_max[input_idx]['min'],
+                            torch.tensor([min_val], dtype=torch.float32),
+                        ]
+                    )
+                    stats_min_max[input_idx]['max'] = torch.cat(
+                        [
+                            stats_min_max[input_idx]['max'],
+                            torch.tensor([max_val], dtype=torch.float32),
+                        ]
+                    )
+
+        return stats_min_max
+
+    def get_static_minmax_range(self, act_tensors):
+
+        act_tensors = self.reshape_batch_tensors(act_tensors)
+        stats_min_max = self.get_minmax_stats(act_tensors)
+        min_vals, max_vals = [], []
+
+        for input_idx, tensor_range in stats_min_max.items():
+            min_val = tensor_range['min'].mean()
+            max_val = tensor_range['max'].mean()
+            min_vals.append(min_val)
+            max_vals.append(max_val)
+
+        return min_vals, max_vals
+
+    def get_static_hist_range(self, act_tensors):
+
+        act_tensors = self.reshape_batch_tensors(act_tensors)
+
+        stats_min_max = self.get_minmax_stats(act_tensors)
+        act_stats_hist = {}
+        for input_idx, tensors in enumerate(act_tensors):
+            for tensor in tensors:
+                data_max = max(
+                    torch.max(stats_min_max[input_idx]['max']),
+                    -torch.min(stats_min_max[input_idx]['min']),
+                )
+                hist = torch.histc(
+                    torch.abs(tensor), bins=int(self.bins), min=0, max=data_max
+                )
+
+                if input_idx not in act_stats_hist:
+                    act_stats_hist[input_idx] = [hist]
+                else:
+                    act_stats_hist[input_idx].append(hist)
+
+        for input_idx, hist in act_stats_hist.items():
+            act_stats_hist[input_idx] = torch.stack(hist).sum(0)
+
+        runing_min_vals, runing_max_vals = self.get_hist_range(
+            stats_min_max, act_stats_hist
+        )
+        return runing_min_vals, runing_max_vals
+
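+    # Running min/max keeps an exponential moving average of the per-batch
+    # ranges: running <- running + alpha * (current - running), i.e. an EMA
+    # with smoothing factor alpha.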
+    def get_static_runing_minmax_range(self, act_tensors, alpha):
+        act_tensors = self.reshape_batch_tensors(act_tensors)
+        runing_min_vals, runing_max_vals = [], []
+        for tensors in act_tensors:
+            runing_min_val, runing_max_val = None, None
+            for tensor in tensors:
+                tensor = self.reshape_tensor(tensor)
+                tensor_range = self.get_minmax_range(tensor)
+                min_val, max_val = tensor_range[0], tensor_range[1]
+
+                if runing_min_val is None or runing_max_val is None:
+                    runing_min_val = min_val
+                    runing_max_val = max_val
+                else:
+                    runing_min_val = runing_min_val + alpha * (min_val - runing_min_val)
+                    runing_max_val = runing_max_val + alpha * (max_val - runing_max_val)
+            runing_min_vals.append(runing_min_val)
+            runing_max_vals.append(runing_max_val)
+
+        return runing_min_vals, runing_max_vals
+
+    def get_static_mse_range(self, act_tensors, norm=2.4):
+        act_tensors = self.reshape_batch_tensors(act_tensors)
+        stats_min_max = self.get_minmax_stats(act_tensors)
+        best_min_vals, best_max_vals = [], []
+
+        for input_idx, tensor_range in stats_min_max.items():
+            _min_val = tensor_range['min'].mean()
+            _max_val = tensor_range['max'].mean()
+            _tensor = torch.stack(act_tensors[input_idx]).float()
+
+            best = float('inf')
+            best_min_val, best_max_val = _min_val, _max_val
+            dev = _tensor.device
+
+            for i in range(int(self.maxshrink * self.mse_grid)):
+                p = 1 - i / self.mse_grid
+
+                xmin = p * _min_val
+                xmax = p * _max_val
+
+                if self.quant_type == 'float-quant' and not self.use_qtorch:
+                    clip_tensor, scales = self.get_float_qparams(
+                        _tensor, (xmin, xmax), dev
+                    )
+                    zeros, qmin, qmax = 0, None, None
+                    q_tensor = self.quant_dequant(
+                        clip_tensor, scales, zeros, qmax, qmin
+                    )
+
+                else:
+                    scales, zeros, qmax, qmin = self.get_qparams((xmin, xmax), dev)
+                    q_tensor = self.quant_dequant(_tensor, scales, zeros, qmax, qmin)
+
+                q_tensor -= _tensor
+                q_tensor.abs_()
+                q_tensor.pow_(norm)
+                err = torch.sum(q_tensor)
+
+                if err < best:
+                    best = err
+                    best_min_val, best_max_val = xmin, xmax
+
+            best_min_vals.append(best_min_val)
+            best_max_vals.append(best_max_val)
+
+        return best_min_vals, best_max_vals
+
     def get_qparams(self, tensor_range, device):
         min_val, max_val = tensor_range[0], tensor_range[1]
         qmin = self.qmin.to(device)
@@ -185,21 +358,22 @@
             zeros = qmin - (min_val / scales)
         return scales, zeros, qmax, qmin
 
-    def get_tensor_qparams(self, tensor, args={}):
-        tensor = self.reshape_tensor(tensor)
-        tensor_range = self.get_tensor_range(tensor, args)
-        scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device)
-        return tensor, scales, zeros, qmax, qmin
-
     def get_batch_tensors_qparams(self, act_tensors, alpha=0.01, args={}):
         scales_list, zeros_list, qmin_list, qmax_list = [], [], [], []
-        runing_min_vals, runing_max_vals = self.get_running_tensor_range(
-            act_tensors, alpha, args
-        )
-        for i in range(len(runing_min_vals)):
-            runing_min_val, runing_max_val = runing_min_vals[i], runing_max_vals[i]
+
+        if self.calib_algo == 'hist':
+            min_vals, max_vals = self.get_static_hist_range(act_tensors)
+        elif self.calib_algo == 'minmax':
+            min_vals, max_vals = self.get_static_minmax_range(act_tensors)
+        elif self.calib_algo == 'runing_minmax':
+            min_vals, max_vals = self.get_static_runing_minmax_range(act_tensors, alpha)
+        elif self.calib_algo == 'mse':
+            min_vals, max_vals = self.get_static_mse_range(act_tensors)
+        else:
+            raise ValueError(f'Unsupported calibration algorithm: {self.calib_algo}')
+
+        for i in range(len(min_vals)):
+            min_val, max_val = min_vals[i], max_vals[i]
             scales, zeros, qmax, qmin = self.get_qparams(
-                (runing_min_val, runing_max_val), runing_min_val.device
+                (min_val, max_val), min_val.device
             )
             scales_list.append(scales)
             zeros_list.append(zeros)
@@ -208,6 +382,32 @@
 
         return scales_list, zeros_list, qmin_list, qmax_list
 
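+    # HQQ-style half-quadratic solver: alternate a hard round/clamp step with
+    # a proximal (shrink_op) update of the weight residual, re-estimating the
+    # zero-points each iteration and annealing beta by kappa; stop as soon as
+    # the mean absolute reconstruction error stops improving.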
+    def optimize_weights_proximal(self, tensor, scales, zeros, qmax, qmin):
+        best_error = 1e4
+        current_beta = self.beta
+        current_kappa = self.kappa
+        scales = 1 / scales
+        for i in range(self.iters):
+            W_q = torch.round(tensor * scales + zeros).clamp(qmin, qmax)
+            W_r = (W_q - zeros) / scales
+            W_e = self.shrink_op(tensor - W_r, current_beta)
+
+            zeros = torch.mean(W_q - (tensor - W_e) * scales, axis=-1, keepdim=True)
+            current_beta *= current_kappa
+            current_error = float(torch.abs(tensor - W_r).mean())
+
+            logger.info(f'iter : {i}, error : {current_error}')
+
+            if current_error < best_error:
+                best_error = current_error
+            else:
+                break
+
+        torch.cuda.empty_cache()
+        scales = 1 / scales
+
+        return scales, zeros
+
     def reshape_tensor(self, tensor, allow_padding=False):
         if self.granularity == 'per_group':
             if tensor.shape[-1] >= self.group_size:
@@ -265,6 +465,25 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         self.qmin = torch.tensor(self.qmin)
         self.qmax = torch.tensor(self.qmax)
 
+    def get_hqq_qparams(self, tensor, args):
+        tensor = tensor.float()
+        tensor = self.reshape_tensor(tensor)
+        tensor_range = self.get_minmax_range(tensor)
+        scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device)
+        best_scales, best_zeros = self.optimize_weights_proximal(
+            tensor, scales, zeros, qmax, qmin
+        )
+        return tensor, best_scales, best_zeros, qmax, qmin
+
+    def get_tensor_qparams(self, tensor, args={}):
+        if self.calib_algo == 'hqq':
+            return self.get_hqq_qparams(tensor, args)
+        else:
+            tensor = self.reshape_tensor(tensor)
+            tensor_range = self.get_tensor_range(tensor, args)
+            scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device)
+            return tensor, scales, zeros, qmax, qmin
+
     def quant(self, tensor, scales, zeros, qmax, qmin):
         if self.round_zp:
             tensor = torch.clamp(self.round_func(tensor / scales) + zeros, qmin, qmax)
@@ -596,16 +815,37 @@ def get_float_qparams(self, tensor, tensor_range, device):
 
         return xc, scales
 
-    def get_tensor_qparams(self, tensor, args={}):
+    def get_hqq_qparams(self, tensor, args):
+        tensor = tensor.float()
         tensor = self.reshape_tensor(tensor)
-        tensor_range = self.get_tensor_range(tensor, args)
+        tensor_range = self.get_minmax_range(tensor)
         if self.use_qtorch:
             scales, zeros, qmax, qmin = self.get_qparams(tensor_range, tensor.device)
         else:
             tensor, scales = self.get_float_qparams(tensor, tensor_range, tensor.device)
             zeros, qmin, qmax = torch.tensor(0), None, None
+        best_scales, best_zeros = self.optimize_weights_proximal(
+            tensor, scales, zeros, qmax, qmin
+        )
+        return tensor, best_scales, best_zeros, qmax, qmin
+
+    def get_tensor_qparams(self, tensor, args={}):
+        if self.calib_algo == 'hqq':
+            return self.get_hqq_qparams(tensor, args)
+        else:
+            tensor = self.reshape_tensor(tensor)
+            tensor_range = self.get_tensor_range(tensor, args)
+            if self.use_qtorch:
+                scales, zeros, qmax, qmin = self.get_qparams(
+                    tensor_range, tensor.device
+                )
+            else:
+                tensor, scales = self.get_float_qparams(
+                    tensor, tensor_range, tensor.device
+                )
+                zeros, qmin, qmax = torch.tensor(0), None, None
-        return tensor, scales, zeros, qmax, qmin
+            return tensor, scales, zeros, qmax, qmin
 
     def quant(self, tensor, scales, zeros, qmax, qmin):
         scales[scales == 0] = 1
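
A minimal usage sketch (not part of the patch), assuming the integer quantizer
class modified above is exposed as `IntegerQuantizer` and that `calib_algo` is
read from the same kwargs as the other config values; the kwarg names mirror
the defaults read in `__init__`:

    import torch
    from llmc.compression.quantization.quant import IntegerQuantizer

    # Static activation calibration over a list of batch tensors.
    quantizer = IntegerQuantizer(
        bit=8, symmetric=False, granularity='per_tensor',
        calib_algo='hist',                 # or 'minmax' / 'runing_minmax' / 'mse'
        bins=2048, hist_threshold=0.999,   # hist config
        maxshrink=0.8, mse_grid=100,       # mse config
    )
    act_tensors = [torch.randn(4, 128) for _ in range(8)]  # calibration batches
    scales, zeros, qmins, qmaxs = quantizer.get_batch_tensors_qparams(act_tensors)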