@@ -23,6 +23,16 @@ def __init__(self, model, quant_config, input, padding_mask, config):
         self.trans = special_config.get('trans', True)
         self.trans_version = special_config.get('trans_version', 'v2')
         self.save_scale = special_config.get('save_scale', False)
+        self.awq_bs = special_config.get('awq_bs', None)
+
+    @torch.no_grad()
+    def scaling_weight(self, w, scales, is_gqa):
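+        # GQA layers share scales across grouped heads; repeat them to match
+        # the weight's channels before the in-place multiply.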
+        if is_gqa:
+            scales_tmp = self.repeat_gqa_scales(scales)
+        else:
+            scales_tmp = scales
+        w_tmp = w.mul_(scales_tmp.view(1, -1))
+        return w_tmp

     @torch.no_grad()
     def get_weight_scale(self, layers_dict):
@@ -49,20 +59,82 @@ def get_weight_scale(self, layers_dict):
         torch.cuda.empty_cache()
         return scale

-    @torch.no_grad()
     def get_act_scale(self, x):
-        return x.abs().view(-1, x.shape[-1]).mean(0)
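+        # Mean of |x| per channel, computed in self._bs-sized chunks to bound
+        # peak memory (assumes x.shape[0] is a multiple of self._bs).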
+        batch_means = []
+        b_num = x.shape[0] // self._bs
+        for num in range(b_num):
+            batch_x = x[num * self._bs:(num + 1) * self._bs]
+            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+            batch_means.append(batch_mean)
+        final_mean = sum(batch_means) / len(batch_means)
+        return final_mean
+
+    @torch.no_grad()
+    def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
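+        # For GQA the search runs on prev_op's output, so activation and
+        # weight statistics are recomputed from prev_op itself.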
+        if is_gqa:
+            x_tmp = prev_op(x)
+            w_tmp = self.get_weight_scale({'prev_op': prev_op})
+        else:
+            x_tmp = x
+            w_tmp = w_max
+
+        x_tmp = self.get_act_scale(x_tmp)
+
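+        # v1 balances activation and weight magnitudes through ratio; v2
+        # derives the scales from activation statistics alone.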
+        if self.trans_version == 'v1':
+            scales = (
+                (x_tmp.pow(ratio) / w_tmp.pow(1 - ratio))
+                .clamp(min=1e-4)
+                .view(-1)
+            )
+        elif self.trans_version == 'v2':
+            scales = x_tmp.pow(ratio).clamp(min=1e-4).view(-1)
+
+        scales = scales / (scales.max() * scales.min()).sqrt()
+        return scales
+
+    def inspect_module_forward(self, x, inspect_module, kwargs):
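+        # Run the inspected block in self._bs-sized chunks and concatenate.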
+        outs = []
+        b_num = x.shape[0] // self._bs
+        for num in range(b_num):
+            _x = x[num * self._bs:(num + 1) * self._bs]
+            out = inspect_module(_x, **kwargs)
+            if isinstance(out, tuple):
+                out = out[0]
+            outs.append(out)
+        return torch.cat(outs, dim=0)

     @torch.no_grad()
     def get_original_out(self, x, inspect_module, subset_kwargs):
         with torch.no_grad():
-            org_out = inspect_module(x, **subset_kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
+            org_out = self.inspect_module_forward(x, inspect_module, subset_kwargs)
         return org_out

+    def calculate_loss(self, org_out, out):
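+        # MSE between original and candidate outputs, accumulated over
+        # self._bs-sized chunks to avoid one large intermediate tensor.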
+        total_loss = 0.0
+        b_num = org_out.shape[0] // self._bs
+        for num in range(b_num):
+            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+            _out = out[num * self._bs:(num + 1) * self._bs]
+            single_loss = (_org_out - _out).float().pow(2).mean().item()
+            total_loss += single_loss
+        return total_loss / b_num
+
     @torch.no_grad()
-    def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
+    def search_scale_subset(
+        self,
+        prev_op,
+        layers_dict,
+        input,
+        inspect_module,
+        is_gqa,
+        subset_kwargs
+    ):
+
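+        # awq_bs caps the mini-batch size used during the search; by default
+        # the whole calibration batch is processed in one pass.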
+        if self.awq_bs is None:
+            self._bs = input[0].shape[0]
+        else:
+            self._bs = self.awq_bs
+
         w_max = self.get_weight_scale(layers_dict)
         # grid search for ratio
         best_error = float('inf')
@@ -89,18 +161,10 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                 x_max = self.get_act_scale(x)

                 ratio = n * 1 / n_grid
-                if self.trans_version == 'v1':
-                    scales = (
-                        (x_max.pow(ratio) / w_max.pow(1 - ratio))
-                        .clamp(min=1e-4)
-                        .view(-1)
-                    )
-                elif self.trans_version == 'v2':
-                    scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
-                scales = scales / (scales.max() * scales.min()).sqrt()
+                scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
                 for layer_name in layers_dict:
                     fc = layers_dict[layer_name]
-                    fc.weight.mul_(scales.view(1, -1))
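+                    # Scale the weight (GQA-aware), then fake-quantize it under
+                    # the current candidate ratio.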
+                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)

                     fc.weight.data = get_wquantizer(
                         self.block_idx,
@@ -110,31 +174,39 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                         self.wquantizer,
                     ).fake_quant_weight_dynamic(fc.weight.data)

-                x_tmp = x / scales.view(1, -1)
+                del x_max
+                gc.collect()
+                torch.cuda.empty_cache()
+
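+                # Apply the inverse scales to the activations so the scaled
+                # subset computes the same outputs up to quantization error.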
+                x_tmp = self.scaling_input(x, scales, is_gqa)
+
                 if not check_w_only(
                     self.block_idx,
                     list(layers_dict.keys())[0],
                     self.mix_bits_map,
                     self.quantizer_mix_bits,
                     self.w_only,
                 ):
-                    x_tmp = get_aquantizer(
-                        self.block_idx,
-                        list(layers_dict.keys())[0],
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.aquantizer,
-                    ).fake_quant_act_dynamic(x_tmp)
-                out = inspect_module(x_tmp, **kwargs)
-
-                if isinstance(out, tuple):
-                    out = out[0]
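+                    # Fake-quantize activations one sample at a time to cap peak
+                    # memory; the index is named j so the outer i (used for
+                    # padding_mask below) stays intact.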
+                    outs = []
+                    for j in range(x_tmp.shape[0]):
+                        _x = x_tmp[j]
+                        _x = get_aquantizer(
+                            self.block_idx,
+                            list(layers_dict.keys())[0],
+                            self.mix_bits_map,
+                            self.quantizer_mix_bits,
+                            self.aquantizer,
+                        ).fake_quant_act_dynamic(_x)
+                        outs.append(_x)
+                    x_tmp = torch.stack(outs)
+
+                out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

                 if self.padding_mask and org_out.shape[1] == self.padding_mask[i].shape[-1]:
                     org_out = org_out * self.padding_mask[i].unsqueeze(dim=-1).to(org_out.device)  # noqa
                     out = out * self.padding_mask[i].unsqueeze(dim=-1).to(out.device)

-                loss = (org_out - out).float().pow(2).mean().item()
+                loss = self.calculate_loss(org_out, out)

                 if len(input) == 1:
                     n_samples = x.shape[0]
@@ -149,6 +221,11 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
                 best_error = loss_mean
                 best_scales = scales_mean

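+            # Free this grid step's outputs before trying the next ratio so
+            # GPU memory is actually reclaimed between candidates.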
+            del org_out
+            del out
+            gc.collect()
+            torch.cuda.empty_cache()
+
         # Synchronize across ranks
         best_error_tensor = torch.tensor([best_error], device='cuda')
         dist.all_reduce(best_error_tensor, op=dist.ReduceOp.MIN)
@@ -248,15 +325,28 @@ def subset_transform(
             and prev_op[0].out_features != layers[0].in_features * 2
             and prev_op[0].out_features != layers[0].in_features
         ):
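+
+            # A shape mismatch is expected for GQA: fall back to the previous
+            # subset's cached inputs instead of skipping the transform.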
-            logger.info('Cannot apply scale. Do not transform this subset.')
-            return
+            if self.has_gqa:
+                is_gqa = True
+                input_keys = list(input_feat.keys())
+                input_name = input_keys[input_keys.index(input_name) - 1]
+            else:
+                logger.info('Cannot apply scale. Do not transform this subset.')
+                return
+        else:
+            is_gqa = False

         scale = self.search_scale_subset(
-            layers_dict, input_feat[input_name], inspect_module, subset_kwargs
+            prev_op[0],
+            layers_dict,
+            input_feat[input_name],
+            inspect_module,
+            is_gqa,
+            subset_kwargs
         )
347
258
348
self .apply_scale (scale , prev_op , layers )
259
- self .update_input_feat (scale , input_feat , layers_dict )
349
+ self .update_input_feat (scale , input_feat , layers_dict , is_gqa )
260
350
261
351
if self .save_scale :
262
352
for n in layers_dict :