@@ -31,8 +31,8 @@ def scaling_weight(self, w, scales, is_gqa):
             scales_tmp = self.repeat_gqa_scales(scales)
         else:
             scales_tmp = scales
-        w_tmp = w.mul_(scales_tmp.view(1, -1))
-        return w_tmp
+        w.mul_(scales_tmp.view(1, -1))
+        return w
 
     @torch.no_grad()
     def get_weight_scale(self, layers_dict):
@@ -60,14 +60,17 @@ def get_weight_scale(self, layers_dict):
         return scale
 
     def get_act_scale(self, x):
-        batch_means = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            batch_x = x[num * self._bs:(num + 1) * self._bs]
-            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
-            batch_means.append(batch_mean)
-        final_mean = sum(batch_means) / len(batch_means)
-        return final_mean
+        if x.shape[0] == self._bs:
+            return x.abs().view(-1, x.shape[-1]).mean(0)
+        else:
+            batch_means = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                batch_x = x[num * self._bs:(num + 1) * self._bs]
+                batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+                batch_means.append(batch_mean)
+            final_mean = sum(batch_means) / len(batch_means)
+            return final_mean
 
     @torch.no_grad()
     def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
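For reference, the single-batch fast path added to get_act_scale computes the same per-channel statistic as the chunked loop it short-circuits. Below is a minimal, self-contained sketch of that equivalence; the tensor shape, the bs value, and the helper names act_scale_chunked / act_scale_fast are illustrative stand-ins, not names from the repo:

import torch

def act_scale_chunked(x, bs):
    # Chunked path: average of per-batch channel-wise mean(|x|) values.
    batch_means = []
    for num in range(x.shape[0] // bs):
        batch_x = x[num * bs:(num + 1) * bs]
        batch_means.append(batch_x.abs().view(-1, batch_x.shape[-1]).mean(0))
    return sum(batch_means) / len(batch_means)

def act_scale_fast(x):
    # Fast path: one reduction over the whole tensor.
    return x.abs().view(-1, x.shape[-1]).mean(0)

x = torch.randn(4, 16, 128)  # illustrative (batch, seq, hidden) activations
assert torch.allclose(act_scale_chunked(x, bs=2), act_scale_fast(x))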
@@ -93,15 +96,22 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
         return scales
 
     def inspect_module_forward(self, x, inspect_module, kwargs):
-        outs = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            _x = x[num * self._bs:(num + 1) * self._bs]
-            out = inspect_module(_x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
-            outs.append(out)
-        return torch.cat(outs, dim=0)
+        if self._bs == x.shape[0]:
+            with torch.no_grad():
+                out = inspect_module(x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+            return out
+        else:
+            outs = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                _x = x[num * self._bs:(num + 1) * self._bs]
+                out = inspect_module(_x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+                outs.append(out)
+            return torch.cat(outs, dim=0)
 
     @torch.no_grad()
     def get_original_out(self, x, inspect_module, subset_kwargs):
@@ -110,14 +120,53 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
         return org_out
 
     def calculate_loss(self, org_out, out):
-        total_loss = 0.0
-        b_num = org_out.shape[0] // self._bs
-        for num in range(b_num):
-            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
-            _out = out[num * self._bs:(num + 1) * self._bs]
-            single_loss = (_org_out - _out).float().pow(2).mean().item()
-            total_loss += single_loss
-        return total_loss / b_num
+        if out.shape[0] == self._bs:
+            return (org_out - out).float().pow(2).mean().item()
+        else:
+            total_loss = 0.0
+            b_num = org_out.shape[0] // self._bs
+            for num in range(b_num):
+                _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+                _out = out[num * self._bs:(num + 1) * self._bs]
+                single_loss = (_org_out - _out).float().pow(2).mean().item()
+                total_loss += single_loss
+            return total_loss / b_num
+
+    def fake_quantize_weight(self, weight, scales, is_gqa, layer_name):
+        weight = self.scaling_weight(weight, scales, is_gqa)
+        weight.data = get_wquantizer(
+            self.block_idx,
+            layer_name,
+            self.mix_bits_map,
+            self.quantizer_mix_bits,
+            self.wquantizer,
+        ).fake_quant_weight_dynamic(weight.data)
+
+        return weight
+
+    def fake_quantize_input(self, x_tmp, layers_dict):
+        if self._bs == x_tmp.shape[0]:
+            x_tmp = get_aquantizer(
+                self.block_idx,
+                list(layers_dict.keys())[0],
+                self.mix_bits_map,
+                self.quantizer_mix_bits,
+                self.aquantizer,
+            ).fake_quant_act_dynamic(x_tmp)
+        else:
+            outs = []
+            for i in range(x_tmp.shape[0]):
+                _x = x_tmp[i]
+                _x = get_aquantizer(
+                    self.block_idx,
+                    list(layers_dict.keys())[0],
+                    self.mix_bits_map,
+                    self.quantizer_mix_bits,
+                    self.aquantizer,
+                ).fake_quant_act_dynamic(_x)
+                outs.append(_x)
+            x_tmp = torch.stack(outs)
+        return x_tmp
 
     @torch.no_grad()
     def search_scale_subset(
@@ -158,25 +207,12 @@ def search_scale_subset(
                 else:
                     org_out = self.get_original_out(x, inspect_module, kwargs)
                     org_out_dict[i] = org_out
-                x_max = self.get_act_scale(x)
 
                 ratio = n * 1 / n_grid
                 scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
                 for layer_name in layers_dict:
                     fc = layers_dict[layer_name]
-                    fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)
-
-                    fc.weight.data = get_wquantizer(
-                        self.block_idx,
-                        layer_name,
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.wquantizer,
-                    ).fake_quant_weight_dynamic(fc.weight.data)
-
-                    del x_max
-                    gc.collect()
-                    torch.cuda.empty_cache()
+                    fc.weight = self.fake_quantize_weight(fc.weight, scales, is_gqa, layer_name)
 
                 x_tmp = self.scaling_input(x, scales, is_gqa)
 
@@ -187,18 +223,7 @@ def search_scale_subset(
                     self.quantizer_mix_bits,
                     self.w_only,
                 ):
-                    outs = []
-                    for i in range(x_tmp.shape[0]):
-                        _x = x_tmp[i]
-                        _x = get_aquantizer(
-                            self.block_idx,
-                            list(layers_dict.keys())[0],
-                            self.mix_bits_map,
-                            self.quantizer_mix_bits,
-                            self.aquantizer,
-                        ).fake_quant_act_dynamic(_x)
-                        outs.append(_x)
-                    x_tmp = torch.stack(outs)
+                    x_tmp = self.fake_quantize_input(x_tmp, layers_dict)
 
                 out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)
 
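Likewise, the fast path in calculate_loss matches the chunked version because the integer division b_num = shape[0] // self._bs implies equally sized chunks (assuming the batch dimension divides evenly), and the average of per-chunk MSE values over equal chunks equals the MSE over the full tensor. A minimal sketch of that equivalence; loss_chunked / loss_fast and the shapes are illustrative, not names from the repo:

import torch

def loss_chunked(org_out, out, bs):
    # Average the MSE of each equally sized chunk.
    total, b_num = 0.0, org_out.shape[0] // bs
    for num in range(b_num):
        o = org_out[num * bs:(num + 1) * bs]
        q = out[num * bs:(num + 1) * bs]
        total += (o - q).float().pow(2).mean().item()
    return total / b_num

def loss_fast(org_out, out):
    # One MSE over the whole tensor.
    return (org_out - out).float().pow(2).mean().item()

a, b = torch.randn(4, 16, 128), torch.randn(4, 16, 128)
assert abs(loss_chunked(a, b, bs=2) - loss_fast(a, b)) < 1e-4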