@@ -390,6 +390,14 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         fx_graph_module = torch.export.export(
             self.llama_graph_module, self.inputs, strict=True
         ).module()
+
+        if quant_dtype == QuantDtype.use_16a4w_block:
+            conv_nodes = [
+                n for n in fx_graph_module.graph.nodes if "conv" in n.name
+            ]
+            block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+            quantizer.set_block_size_map(block_size_map)
+
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)

         logging.info("Quantizing the model...")
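
Here `(1, 64, 1, 1)` means each quantization block spans one output channel, 64 input channels, and one spatial position of a conv weight, so every group of 64 input channels gets its own 4-bit scale. A minimal sketch of what such per-block scales look like, assuming a symmetric signed 4-bit scheme (the `blockwise_scales` helper is hypothetical; the quantizer's actual internals may differ):

import torch

def blockwise_scales(w: torch.Tensor, block=(1, 64, 1, 1), n_bits=4) -> torch.Tensor:
    # Hypothetical helper, not part of the PR: one symmetric scale per
    # (1, 64, 1, 1) block of a conv weight shaped (O, I, H, W).
    out_ch, in_ch, kh, kw = w.shape
    assert in_ch % block[1] == 0, "input channels must divide evenly into blocks"
    grouped = w.reshape(out_ch, in_ch // block[1], block[1], kh, kw)
    qmax = 2 ** (n_bits - 1) - 1             # 7 for signed 4-bit
    return grouped.abs().amax(dim=2) / qmax  # shape (out_ch, in_ch // 64, kh, kw)

# A 64x128x1x1 pointwise conv weight yields 64 * 2 * 1 * 1 = 128 block scales.
scales = blockwise_scales(torch.randn(64, 128, 1, 1))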
@@ -574,13 +582,14 @@ def permute(w, heads):
     fixed_point_type["kv_type"] = torch.uint8
     if args.ptq == "8a8w":
         fixed_point_type["io_type"] = torch.uint8
-    elif args.ptq == "16a4w":
+    elif args.ptq in ("16a4w", "16a4w_block"):
         fixed_point_type["io_type"] = torch.uint16
     else:
         assert args.ptq in [
             "8a8w",
             "16a4w",
-        ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w."
+            "16a4w_block",
+        ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block."
     quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

     assert args.tokenizer_model is not None, "Need tokenizer model for calibration"
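
The `getattr(QuantDtype, f"use_{args.ptq}")` line resolves the `--ptq` string to an enum member, which is why the new mode only needs the string added to the allow-list above. A self-contained illustration, using a hypothetical stand-in for the real `QuantDtype` enum:

from enum import Enum, auto

class QuantDtype(Enum):  # hypothetical mirror of the real enum
    use_8a8w = auto()
    use_16a4w = auto()
    use_16a4w_block = auto()

# "--ptq 16a4w_block" resolves to the new member via the f-string lookup.
assert getattr(QuantDtype, "use_16a4w_block") is QuantDtype.use_16a4w_block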
@@ -954,7 +963,7 @@ def _build_parser():
     parser.add_argument(
         "-P",
         "--ptq",
-        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.",
+        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
         type=str,
     )
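
As a quick sanity check, the new value flows through the parser unchanged; a standalone argparse snippet mirroring the flag definition above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-P", "--ptq", type=str)  # same flag as in _build_parser()
args = parser.parse_args(["--ptq", "16a4w_block"])
assert args.ptq == "16a4w_block"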