@@ -58,7 +58,7 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
5858
5959 def get_replacement_params (self , mode = 'fake_quant' , w_only = False , name = None ):
6060 params_dict = {}
61- if mode == 'fake_quant' :
61+ if mode in [ 'fake_quant' , 'fake_quant_wo_kv' ] :
6262 if not self .mix_bits :
6363 params_dict ['a_qdq' ] = (
6464 partial (self .a_qdq , aquantizer = self .aquantizer )
@@ -229,17 +229,16 @@ def set_quant_config(self):
229229 # set kv cache quant config
230230 if 'kvcache' in self .quant_config :
231231 self .quant_config ['kvcache' ]['static' ] = self .act_static
232+ kv_special_cfg = self .quant_config ['kvcache' ].get ('special' , {})
233+ logger .info (kv_special_cfg )
234+ act_static_cfg = {}
232235 if self .act_static :
233- self .kv_module = KV_REGISTRY [self .quant_config ['kvcache' ]['method' ]](
234- self .quant_type , self .quant_config ['kvcache' ],
235- self .model .model_config .num_hidden_layers , self .config .calib .n_samples ,
236- self .config .calib .bs
237- )
238- else :
239- self .kv_module = KV_REGISTRY [self .quant_config ['kvcache' ]['method' ]](
240- self .quant_type , self .quant_config ['kvcache' ],
241- self .model .model_config .num_hidden_layers
242- )
236+ act_static_cfg .update (self .config .calib .n_sample )
237+ act_static_cfg .update (self .config .calib .bs )
238+ self .kv_module = KV_REGISTRY [self .quant_config ['kvcache' ]['method' ]](
239+ self .quant_type , self .quant_config ['kvcache' ],
240+ self .model .model_config .num_hidden_layers , ** kv_special_cfg , ** act_static_cfg
241+ )
243242 self .quant_kvcache = True
244243 self .model .kvcache_buffer .append (self .kv_module )
245244 else :
@@ -860,6 +859,7 @@ def deploy(self, quant_format, keep_device=False):
860859 module_mapping = {
861860 'origin_float' : OriginFloatLinear ,
862861 'fake_quant' : EffcientFakeQuantLinear ,
862+ 'fake_quant_wo_kv' : EffcientFakeQuantLinear ,
863863 }
864864 module_mapping .update (_REALQUANT_LINEAR_MAP_ )
865865
@@ -884,10 +884,12 @@ def deploy(self, quant_format, keep_device=False):
884884 self .set_non_linear_mode (quant_format , self .model .model , False )
885885
886886 if self .quant_kvcache :
887- if quant_format == 'transformed' :
888- self .kv_module .transformed = True
887+ if quant_format == 'origin_float' :
888+ self .kv_module .use_org_kv = True
889+ elif quant_format == 'fake_quant_wo_kv' :
890+ self .kv_module .use_org_kv = True
889891 elif quant_format == 'fake_quant' :
890- self .kv_module .transformed = False
892+ self .kv_module .use_org_kv = False
891893 if self .act_static :
892894 self .kv_module .calib = False
893895
0 commit comments