@@ -95,8 +95,13 @@ def load_test_model(opt, device_id=0, model_path=None):
 
     model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
 
-    model_opt.quant_layers = opt.quant_layers
-    model_opt.quant_type = opt.quant_type
+    if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
+        "llm_awq",
+        "aawq_gemm",
+        "aawq_gemv",
+    ]:
+        model_opt.quant_layers = opt.quant_layers
+        model_opt.quant_type = opt.quant_type
 
     if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel":
         model_opt.world_size = opt.world_size
@@ -304,6 +309,21 @@ def build_base_model(model_opt, vocabs):
             model = replace_bnb_linear(
                 model, module_to_convert=nonlora_to_quant, q_type=model_opt.quant_type
             )
+        elif model_opt.quant_type in ["llm_awq", "aawq_gemm", "aawq_gemv"]:
+            logger.info(
+                "%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
+            )
+            try:
+                from onmt.modules.awq_linear import replace_awq_linear
+            except ImportError:
+                raise ImportError("Install llm-awq/AutoAWQ to use awq quantized model")
+            model = replace_awq_linear(
+                model,
+                module_to_convert=nonlora_to_quant,
+                w_bit=model_opt.w_bit,
+                group_size=model_opt.group_size,
+                q_type=model_opt.quant_type,
+            )
         else:
             logger.info("compression type %s not supported." % model_opt.quant_type)