
Commit e6c7b30

Authored by chunit-quic (Chun-I Tsai) and Chun-I Tsai
Qualcomm AI Engine Direct - Add block quantization to llama (#10225)
- Add CLI argument to use block quantization for llama

Co-authored-by: Chun-I Tsai <[email protected]>
Parent: 8b4500b

File tree

1 file changed: +12 −3 lines

  • examples/qualcomm/oss_scripts/llama

examples/qualcomm/oss_scripts/llama/llama.py (+12 −3)
```diff
@@ -390,6 +390,14 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         fx_graph_module = torch.export.export(
             self.llama_graph_module, self.inputs, strict=True
         ).module()
+
+        if quant_dtype == QuantDtype.use_16a4w_block:
+            conv_nodes = [
+                n for n in fx_graph_module.graph.nodes if "conv" in n.name
+            ]
+            block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+            quantizer.set_block_size_map(block_size_map)
+
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)

         logging.info("Quantizing the model...")
```
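For context on what the hunk above configures: a `(1, 64, 1, 1)` block size over a conv weight of shape `(out_ch, in_ch, kH, kW)` means each quantization block spans 1 output channel, 64 input channels, and a single spatial tap, with one scale per block. The sketch below illustrates that grouping; it is not code from this commit, and the helper name, the symmetric int4 math, and the sample shapes are illustrative assumptions.

```python
# Minimal sketch of (1, 64, 1, 1) block quantization for a conv weight.
# Each block of 64 values along the input-channel axis shares one scale.
import torch


def per_block_int4_dequant(weight: torch.Tensor, block: int = 64):
    out_ch, in_ch, kh, kw = weight.shape
    assert in_ch % block == 0, "in_ch must divide evenly into blocks"
    # Group the input-channel axis into blocks of `block` values.
    grouped = weight.reshape(out_ch, in_ch // block, block, kh, kw)
    # One symmetric scale per block: largest |w| in the block maps to 7 (int4 max).
    scales = grouped.abs().amax(dim=2, keepdim=True).clamp_min(1e-9) / 7
    q = torch.clamp(torch.round(grouped / scales), min=-8, max=7)
    # Dequantize back to float so the block-wise rounding error can be inspected.
    return (q * scales).reshape(out_ch, in_ch, kh, kw), scales.squeeze(2)


w = torch.randn(8, 128, 3, 3)      # hypothetical conv weight
w_dq, scales = per_block_int4_dequant(w)
print(scales.shape)                # torch.Size([8, 2, 3, 3]): 2 blocks of 64 channels
print((w - w_dq).abs().max())      # per-block rounding error
```

Tracking weight ranges per block rather than per channel is what the 16a4w_block mode buys at 4-bit weight precision: a single outlier weight only inflates the scale of its own 64-value block.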
```diff
@@ -574,13 +582,14 @@ def permute(w, heads):
         fixed_point_type["kv_type"] = torch.uint8
         if args.ptq == "8a8w":
             fixed_point_type["io_type"] = torch.uint8
-        elif args.ptq == "16a4w":
+        elif args.ptq in ("16a4w", "16a4w_block"):
             fixed_point_type["io_type"] = torch.uint16
         else:
             assert args.ptq in [
                 "8a8w",
                 "16a4w",
-            ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w."
+                "16a4w_block",
+            ], f"No support for quant type {args.ptq}. Support 8a8w, 16a4w and 16a4w_block."
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

         assert args.tokenizer_model is not None, "Need tokenizer model for calibration"
```
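The `getattr` on the context line above is why the accepted strings must match the `QuantDtype` member names minus the `use_` prefix. A stand-in sketch of that mapping; the real `QuantDtype` enum lives in the Qualcomm backend's quantizer module, and the members other than `use_16a4w_block` are inferred here from the accepted `--ptq` strings:

```python
# Stand-in enum for illustration only; not the backend's actual definition.
from enum import Enum, auto


class QuantDtype(Enum):
    use_8a8w = auto()
    use_16a4w = auto()
    use_16a4w_block = auto()


for ptq in ("8a8w", "16a4w", "16a4w_block"):
    print(getattr(QuantDtype, f"use_{ptq}"))  # e.g. QuantDtype.use_16a4w_block
```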
```diff
@@ -954,7 +963,7 @@ def _build_parser():
     parser.add_argument(
         "-P",
         "--ptq",
-        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.",
+        help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w, 16a4w and 16a4w_block.",
         type=str,
     )
```
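Usage note: passing `--ptq 16a4w_block` flows through the `getattr(QuantDtype, f"use_{args.ptq}")` mapping to `QuantDtype.use_16a4w_block`, which enables the per-conv `block_size_map` path in `quantize()` shown in the first hunk.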
