From 7d3e96a39d2d823a74219471e24edad7fc1bf4a4 Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 17 Oct 2024 00:46:31 -0700
Subject: [PATCH 1/2] Adding torchao apis to gpt-fast

Summary: adding torchao apis to gpt-fast and some minor tweaks

Test Plan: (in progress)

export MODEL_REPO=meta-llama/Meta-Llama-3-8B

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode torchao-int8
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int8.pth --compile
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int8.pth
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int8.pth --tasks wikitext --compile

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode torchao-int4-hqq
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth --compile
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth --tasks wikitext --compile

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode torchao-int4
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4.pth --compile
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4.pth
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_torchao-int4.pth --tasks wikitext --compile

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --tasks wikitext --compile

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --tasks wikitext --compile

Reviewers:

Subscribers:

Tasks:

Tags:
---
 eval.py     |  6 +++++-
 generate.py |  9 ++++++---
 quantize.py | 34 +++++++++++++++++++++-------------
 3 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/eval.py b/eval.py
index 712ee181..c2dd6976 100644
--- a/eval.py
+++ b/eval.py
@@ -92,7 +92,11 @@ def __init__(
         tokenizer,
         max_seq_length: Optional[int]=None,
     ):
-        super().__init__()
+        try:
+            super().__init__()
+        except TypeError:
+            # lm_eval 0.4.2 removed the default init
+            super().__init__("gpt2", device="cuda")
         self._model = model
         self._tokenizer = tokenizer
         self._device = torch.device('cuda')
diff --git a/generate.py b/generate.py
index b7a4c113..51ba8600 100644
--- a/generate.py
+++ b/generate.py
@@ -221,13 +221,16 @@ def _load_model(checkpoint_path, device, precision, use_tp):
     with torch.device('meta'):
         model = Transformer.from_name(checkpoint_path.parent.name)
 
-    if "int8" in str(checkpoint_path):
+    # don't have to transform the model when using torchao apis
+    is_torchao = 'torchao-' in str(checkpoint_path)
+
+    if "int8" in str(checkpoint_path) and not is_torchao:
         print("Using int8 weight-only quantization!")
         from quantize import WeightOnlyInt8QuantHandler
         simple_quantizer = WeightOnlyInt8QuantHandler(model)
         model = simple_quantizer.convert_for_runtime()
 
-    if "int4" in str(checkpoint_path):
+    if "int4" in str(checkpoint_path) and not is_torchao:
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
         groupsize = int(path_comps[-2][1:])
@@ -235,7 +238,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
         model = simple_quantizer.convert_for_runtime()
 
-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=not is_torchao)
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
diff --git a/quantize.py b/quantize.py
index fb566421..63f29ada 100644
--- a/quantize.py
+++ b/quantize.py
@@ -554,22 +554,33 @@ def quantize(
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
 
-    if mode == 'int8':
+    dir_name = checkpoint_path.parent
+    base_name = checkpoint_path.name
+
+    if 'torchao-int4' in mode:
+        import torchao
+        from torchao.quantization import (quantize_, int4_weight_only)
+        use_hqq = 'hqq' in mode
+        print(f"Quantizing model weights for int4 weight-only affine groupwise quantization{' with hqq' if use_hqq else ''}")
+        quantize_(model, int4_weight_only(group_size=groupsize, use_hqq=use_hqq), device='cuda')
+        quantized_state_dict = model.state_dict()
+        new_base_name = base_name.replace('.pth', f'{label}{mode}.pth')
+    elif 'torchao-int8' in mode:
+        import torchao
+        from torchao.quantization import (quantize_, int8_weight_only)
+        print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
+        quantize_(model, int8_weight_only())
+        quantized_state_dict = model.state_dict()
+        new_base_name = base_name.replace('.pth', f'{label}{mode}.pth')
+    elif mode == 'int8':
         print("Quantizing model weights for int8 weight-only symmetric per-channel quantization")
         quant_handler = WeightOnlyInt8QuantHandler(model)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
-
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f'{label}int8.pth')
-
     elif mode == 'int4':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
-
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
 
     elif mode == 'int4-gptq':
@@ -590,12 +601,9 @@ def quantize(
             calibration_seq_length,
             pad_calibration_inputs
         )
-
-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
         new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
     else:
-        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
+        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gptq, torchao-int4, torchao-int8, torchao-int4-hqq]")
 
     quantize_path = dir_name / new_base_name
     print(f"Writing quantized weights to {quantize_path}")
@@ -608,7 +616,7 @@
     import argparse
     parser = argparse.ArgumentParser(description='Quantize a model.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Path to the model checkpoint to be quantized.')
-    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq'], help='type of quantization to perform')
+    parser.add_argument('--mode', '-q', type=str, default='int8', choices=['int8', 'int4', 'int4-gptq', 'torchao-int4', 'torchao-int8', 'torchao-int4-hqq'], help='type of quantization to perform')
     parser.add_argument('--groupsize', type=int, default=32, help='Group size for int4 quantization.')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
     parser.add_argument('--calibration_limit', type=int, default=1000, help='number of samples to use for gptq calibration')

From 7144ffbe594f42fdff34732a9ee2c4f0741f12fc Mon Sep 17 00:00:00 2001
From: HDCharles
Date: Thu, 17 Oct 2024 17:13:48 -0700
Subject: [PATCH 2/2] Adding info to readme

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/README.md b/README.md
index 33d7216e..86e663dd 100644
--- a/README.md
+++ b/README.md
@@ -180,6 +180,17 @@ To run with int4, just pass the int4 checkpoint to generate.py.
 python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
 ```
 
+### TorchAO Quantization APIs
+There are also options to use the TorchAO quantization APIs with quantize.py, via the torchao-int4, torchao-int8 and torchao-int4-hqq modes.
+To generate this version of the model:
+```bash
+# Spits out model at checkpoints/$MODEL_REPO/model_torchao-int4-hqq.pth
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode torchao-int4-hqq --groupsize 32
+```
+In addition to adding the hqq option for int4 quantization, the primary difference between the TorchAO quantization APIs and the gpt-fast ones is that checkpoints saved with the TorchAO APIs
+can be loaded directly by generate.py, rather than requiring the model to be transformed for quantization before the state dict is loaded.
+
+
 ## Speculative Sampling
 
 To generate with speculative sampling (DRAFT_MODEL_REPO should point to a smaller model compared with MODEL_REPO).
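
For readers skimming the patch, here is a minimal, self-contained sketch of the save/load round trip that the torchao checkpoints rely on. It is an illustration rather than the actual quantize.py/generate.py code: `TinyModel` is a hypothetical stand-in for gpt-fast's Transformer, and it assumes torchao is installed and a CUDA device is available.

```python
# Sketch only: mirrors the flow the patch enables, not the exact gpt-fast code.
import torch
from torchao.quantization import quantize_, int8_weight_only

class TinyModel(torch.nn.Module):  # hypothetical stand-in for gpt-fast's Transformer
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64, bias=False)

    def forward(self, x):
        return self.linear(x)

model = TinyModel().to(dtype=torch.bfloat16, device="cuda")
quantize_(model, int8_weight_only())  # swaps weights for torchao quantized tensors in place
torch.save(model.state_dict(), "model_torchao-int8.pth")

# The saved checkpoint loads directly into a fresh model; no convert_for_runtime() pass
# is needed. weights_only=False mirrors the generate.py change, since torchao tensor
# subclasses are not handled by the weights_only unpickler.
fresh = TinyModel().to(dtype=torch.bfloat16, device="cuda")
state_dict = torch.load("model_torchao-int8.pth", mmap=True, weights_only=False)
fresh.load_state_dict(state_dict, assign=True)
```

The `assign=True` on load is what makes this work: it keeps the deserialized quantized tensors as-is instead of copying them into the existing float parameters, which is why generate.py can skip the quant-handler branches for torchao checkpoints.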