@@ -31,8 +31,8 @@ def scaling_weight(self, w, scales, is_gqa):
             scales_tmp = self.repeat_gqa_scales(scales)
         else:
             scales_tmp = scales
-        w_tmp = w.mul_(scales_tmp.view(1, -1))
-        return w_tmp
+        w.mul_(scales_tmp.view(1, -1))
+        return w

     @torch.no_grad()
     def get_weight_scale(self, layers_dict):
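Side note on the `scaling_weight` hunk: `Tensor.mul_` already mutates `w` in place, so the old `w_tmp` was only an alias of `w`; returning `w` directly makes that explicit. A minimal standalone sketch of the per-input-channel scaling involved (the function name and shapes here are illustrative, not part of the repo):

```python
import torch

# Illustrative sketch: scale each input channel of a Linear weight in place,
# mirroring w.mul_(scales_tmp.view(1, -1)) above. Assumed shapes:
# weight is [out_features, in_features], scales is [in_features].
def scale_weight_inplace(weight: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    weight.mul_(scales.view(1, -1))  # in place: the caller's tensor is modified
    return weight                    # same tensor object, no copy

w = torch.randn(4, 8)
s = torch.rand(8) + 0.5
w_ref = w.clone()
out = scale_weight_inplace(w, s)
assert out is w                      # returned alias, not a new tensor
assert torch.allclose(w, w_ref * s)  # column j scaled by s[j]
```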
@@ -60,14 +60,17 @@ def get_weight_scale(self, layers_dict):
         return scale

     def get_act_scale(self, x):
-        batch_means = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            batch_x = x[num * self._bs:(num + 1) * self._bs]
-            batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
-            batch_means.append(batch_mean)
-        final_mean = sum(batch_means) / len(batch_means)
-        return final_mean
+        if x.shape[0] == self._bs:
+            return x.abs().view(-1, x.shape[-1]).mean(0)
+        else:
+            batch_means = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                batch_x = x[num * self._bs:(num + 1) * self._bs]
+                batch_mean = batch_x.abs().view(-1, batch_x.shape[-1]).mean(0)
+                batch_means.append(batch_mean)
+            final_mean = sum(batch_means) / len(batch_means)
+            return final_mean

     @torch.no_grad()
     def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
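Note on the new fast path in `get_act_scale`: when the calibration tensor already fits in one micro-batch (`x.shape[0] == self._bs`), the per-channel mean of `|x|` is taken in a single shot; the chunked average it bypasses gives the same result whenever the batch splits into equal-sized chunks. A rough standalone sketch, with illustrative function names:

```python
import torch

# Illustrative sketch (names are not the repo's API): per-channel mean(|x|),
# computed in one shot vs. averaged over equal-sized chunks as in get_act_scale.
def act_scale_full(x: torch.Tensor) -> torch.Tensor:
    return x.abs().view(-1, x.shape[-1]).mean(0)

def act_scale_chunked(x: torch.Tensor, bs: int) -> torch.Tensor:
    means = [act_scale_full(x[i * bs:(i + 1) * bs]) for i in range(x.shape[0] // bs)]
    return sum(means) / len(means)

x = torch.randn(8, 16, 32)  # [batch, seq_len, hidden]; batch divisible by bs
assert torch.allclose(act_scale_full(x), act_scale_chunked(x, bs=2), atol=1e-5)
```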
@@ -93,15 +96,22 @@ def get_scales(self, prev_op, x, w_max, is_gqa, ratio):
         return scales

     def inspect_module_forward(self, x, inspect_module, kwargs):
-        outs = []
-        b_num = x.shape[0] // self._bs
-        for num in range(b_num):
-            _x = x[num * self._bs:(num + 1) * self._bs]
-            out = inspect_module(_x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
-            outs.append(out)
-        return torch.cat(outs, dim=0)
+        if self._bs == x.shape[0]:
+            with torch.no_grad():
+                out = inspect_module(x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+            return out
+        else:
+            outs = []
+            b_num = x.shape[0] // self._bs
+            for num in range(b_num):
+                _x = x[num * self._bs:(num + 1) * self._bs]
+                out = inspect_module(_x, **kwargs)
+                if isinstance(out, tuple):
+                    out = out[0]
+                outs.append(out)
+            return torch.cat(outs, dim=0)

     @torch.no_grad()
     def get_original_out(self, x, inspect_module, subset_kwargs):
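Similarly for `inspect_module_forward`: when the whole batch fits, one forward call under `torch.no_grad()` replaces the chunk-and-concatenate loop, which for a batch-independent module produces the same output. An illustrative check (toy `nn.Linear`; names and shapes are assumptions, not the repo's):

```python
import torch
import torch.nn as nn

# Illustrative check: chunked forward + torch.cat along dim 0 matches a single
# forward call for a module that treats batch items independently.
module = nn.Linear(16, 16)
x = torch.randn(8, 4, 16)
bs = 2

with torch.no_grad():
    full = module(x)
    chunks = [module(x[i * bs:(i + 1) * bs]) for i in range(x.shape[0] // bs)]
    chunked = torch.cat(chunks, dim=0)

assert torch.allclose(full, chunked, atol=1e-6)
```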
@@ -110,14 +120,53 @@ def get_original_out(self, x, inspect_module, subset_kwargs):
         return org_out

     def calculate_loss(self, org_out, out):
-        total_loss = 0.0
-        b_num = org_out.shape[0] // self._bs
-        for num in range(b_num):
-            _org_out = org_out[num * self._bs:(num + 1) * self._bs]
-            _out = out[num * self._bs:(num + 1) * self._bs]
-            single_loss = (_org_out - _out).float().pow(2).mean().item()
-            total_loss += single_loss
-        return total_loss / b_num
+        if out.shape[0] == self._bs:
+            return (org_out - out).float().pow(2).mean().item()
+        else:
+            total_loss = 0.0
+            b_num = org_out.shape[0] // self._bs
+            for num in range(b_num):
+                _org_out = org_out[num * self._bs:(num + 1) * self._bs]
+                _out = out[num * self._bs:(num + 1) * self._bs]
+                single_loss = (_org_out - _out).float().pow(2).mean().item()
+                total_loss += single_loss
+            return total_loss / b_num
+
+    def fake_quantize_weight(self, weight, scales, is_gqa, layer_name):
+        weight = self.scaling_weight(weight, scales, is_gqa)
+        weight.data = get_wquantizer(
+            self.block_idx,
+            layer_name,
+            self.mix_bits_map,
+            self.quantizer_mix_bits,
+            self.wquantizer,
+        ).fake_quant_weight_dynamic(weight.data)
+
+        return weight
+
+    def fake_quantize_input(self, x_tmp, layers_dict):
+        if self._bs == x_tmp.shape[0]:
+            x_tmp = get_aquantizer(
+                self.block_idx,
+                list(layers_dict.keys())[0],
+                self.mix_bits_map,
+                self.quantizer_mix_bits,
+                self.aquantizer,
+            ).fake_quant_act_dynamic(x_tmp)
+        else:
+            outs = []
+            for i in range(x_tmp.shape[0]):
+                _x = x_tmp[i]
+                _x = get_aquantizer(
+                    self.block_idx,
+                    list(layers_dict.keys())[0],
+                    self.mix_bits_map,
+                    self.quantizer_mix_bits,
+                    self.aquantizer,
+                ).fake_quant_act_dynamic(_x)
+                outs.append(_x)
+            x_tmp = torch.stack(outs)
+        return x_tmp

     @torch.no_grad()
     def search_scale_subset(
@@ -158,25 +207,12 @@ def search_scale_subset(
             else:
                 org_out = self.get_original_out(x, inspect_module, kwargs)
                 org_out_dict[i] = org_out
-            x_max = self.get_act_scale(x)

             ratio = n * 1 / n_grid
             scales = self.get_scales(prev_op, x, w_max, is_gqa, ratio)
             for layer_name in layers_dict:
                 fc = layers_dict[layer_name]
-                fc.weight = self.scaling_weight(fc.weight, scales, is_gqa)
-
-                fc.weight.data = get_wquantizer(
-                    self.block_idx,
-                    layer_name,
-                    self.mix_bits_map,
-                    self.quantizer_mix_bits,
-                    self.wquantizer,
-                ).fake_quant_weight_dynamic(fc.weight.data)
-
-            del x_max
-            gc.collect()
-            torch.cuda.empty_cache()
+                fc.weight = self.fake_quantize_weight(fc.weight, scales, is_gqa, layer_name)

             x_tmp = self.scaling_input(x, scales, is_gqa)

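With `fake_quantize_weight` factored out, each grid point in `search_scale_subset` now reads as: compute scales for the ratio, scale-and-fake-quantize the weights, scale the inputs, run the module, and score against the original output. A heavily simplified, self-contained sketch of that search using a toy per-tensor quantizer instead of the repo's `get_wquantizer`/`get_aquantizer` machinery (all names and the scale formula below are illustrative):

```python
import torch

# Toy symmetric per-tensor fake quantizer; the real project uses the quantizer
# returned by get_wquantizer / get_aquantizer instead.
def toy_fake_quant(t: torch.Tensor, n_bits: int = 4) -> torch.Tensor:
    qmax = 2 ** (n_bits - 1) - 1
    scale = t.abs().amax().clamp(min=1e-8) / qmax
    return (t / scale).round().clamp(-qmax - 1, qmax) * scale

w = torch.randn(16, 32)          # stand-in for fc.weight
x = torch.randn(4, 8, 32)        # stand-in calibration activations
org_out = x @ w.t()              # cf. get_original_out

best_loss, best_ratio = float("inf"), None
n_grid = 20
for n in range(n_grid):
    ratio = n * 1 / n_grid
    # Simplified stand-in for get_scales: activation-magnitude based scales.
    scales = x.abs().view(-1, 32).mean(0).pow(ratio).clamp(min=1e-4)
    w_q = toy_fake_quant(w * scales)      # cf. fake_quantize_weight
    out = (x / scales) @ w_q.t()          # scale input inversely to the weights
    loss = (org_out - out).float().pow(2).mean().item()  # cf. calculate_loss
    if loss < best_loss:
        best_loss, best_ratio = loss, ratio
```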
@@ -187,18 +223,7 @@ def search_scale_subset(
                 self.quantizer_mix_bits,
                 self.w_only,
             ):
-                outs = []
-                for i in range(x_tmp.shape[0]):
-                    _x = x_tmp[i]
-                    _x = get_aquantizer(
-                        self.block_idx,
-                        list(layers_dict.keys())[0],
-                        self.mix_bits_map,
-                        self.quantizer_mix_bits,
-                        self.aquantizer,
-                    ).fake_quant_act_dynamic(_x)
-                    outs.append(_x)
-                x_tmp = torch.stack(outs)
+                x_tmp = self.fake_quantize_input(x_tmp, layers_dict)

             out = self.inspect_module_forward(x_tmp, inspect_module, kwargs)

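On `fake_quantize_input`, which the hunk above swaps in for the inlined loop: it fake-quantizes the scaled activations either in one call (whole batch) or sample by sample and re-stacks them; depending on the quantizer's granularity the two branches need not be bit-identical, but the output shape is preserved either way. A toy sketch with an assumed per-tensor dynamic quantizer standing in for `fake_quant_act_dynamic`:

```python
import torch

# Toy per-tensor dynamic fake quantizer (assumption, not the repo's quantizer).
def toy_fake_quant_act(t: torch.Tensor, n_bits: int = 8) -> torch.Tensor:
    qmax = 2 ** (n_bits - 1) - 1
    scale = t.abs().amax().clamp(min=1e-8) / qmax
    return (t / scale).round().clamp(-qmax - 1, qmax) * scale

x_tmp = torch.randn(4, 8, 32)
whole = toy_fake_quant_act(x_tmp)                        # whole-batch fast path
per_sample = torch.stack([toy_fake_quant_act(x_tmp[i])   # per-sample fallback
                          for i in range(x_tmp.shape[0])])
assert whole.shape == per_sample.shape == x_tmp.shape
```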