|
| 1 | +Quick Start Guide |
| 2 | +----------------- |
| 3 | + |
| 4 | +In this quick start guide, we will explore how to perform basic quantization using torchao. |
| 5 | +First, install the latest stable torchao release:: |
| 6 | + |
| 7 | + pip install torchao |
| 8 | + |
| 9 | +If you prefer to use the nightly release, you can install torchao using the following |
| 10 | +command instead:: |
| 11 | + |
| 12 | + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 |
| 13 | + |
| 14 | +torchao is compatible with the latest 3 major versions of PyTorch, which you will also |
| 15 | +need to install (`detailed instructions <https://pytorch.org/get-started/locally/>`__):: |
| 16 | + |
| 17 | + pip install torch |
| 18 | + |
| 19 | + |
| 20 | +First Quantization Example |
| 21 | +========================== |
| 22 | + |
| 23 | +The main entry point for quantization in torchao is the `quantize_ <https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_>`__ API. |
| 24 | +This function mutates your model inplace to insert the custom quantization logic based |
| 25 | +on what the user configures. All code in this guide can be found in this `example script <https://github.com/pytorch/ao/blob/main/scripts/quick_start.py>`__. |
| 26 | +First, let's set up our toy model: |
| 27 | + |
| 28 | +.. code:: py |
| 29 | +
|
| 30 | + import copy |
| 31 | + import torch |
| 32 | + |
| 33 | + class ToyLinearModel(torch.nn.Module): |
| 34 | + def __init__(self, m: int, n: int, k: int): |
| 35 | + super().__init__() |
| 36 | + self.linear1 = torch.nn.Linear(m, n, bias=False) |
| 37 | + self.linear2 = torch.nn.Linear(n, k, bias=False) |
| 38 | + |
| 39 | + def forward(self, x): |
| 40 | + x = self.linear1(x) |
| 41 | + x = self.linear2(x) |
| 42 | + return x |
| 43 | + |
| 44 | + model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") |
| 45 | + |
| 46 | + # Optional: compile model for faster inference and generation |
| 47 | + model = torch.compile(model, mode="max-autotune", fullgraph=True) |
| 48 | + model_bf16 = copy.deepcopy(model) |
| 49 | +
|
| 50 | +Now we call our main quantization API to quantize the linear weights |
| 51 | +in the model to int4 inplace. More specifically, this applies uint4 |
| 52 | +weight-only asymmetric per-group quantization, leveraging the |
| 53 | +`tinygemm int4mm CUDA kernel <https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/cuda/int4mm.cu#L1097>`__ |
| 54 | +for efficient mixed dtype matrix multiplication: |
| 55 | + |
| 56 | +.. code:: py |
| 57 | +
|
| 58 | + # torch 2.4+ only |
| 59 | + from torchao.quantization import int4_weight_only, quantize_ |
| 60 | + quantize_(model, int4_weight_only(group_size=32)) |
| 61 | +
|
| 62 | +The quantized model is now ready to use! Note that the quantization |
| 63 | +logic is inserted through tensor subclasses, so there is no change |
| 64 | +to the overall model structure; only the weights tensors are updated, |
| 65 | +but `nn.Linear` modules stay as `nn.Linear` modules: |
| 66 | + |
| 67 | +.. code:: py |
| 68 | +
|
| 69 | + >>> model.linear1 |
| 70 | + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) |
| 71 | +
|
| 72 | + >>> model.linear2 |
| 73 | + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) |
| 74 | +
|
| 75 | +First, verify that the int4 quantized model is roughly a quarter of |
| 76 | +the size of the original bfloat16 model: |
| 77 | + |
| 78 | +.. code:: py |
| 79 | +
|
| 80 | + >>> import os |
| 81 | + >>> torch.save(model, "/tmp/int4_model.pt") |
| 82 | + >>> torch.save(model_bf16, "/tmp/bfloat16_model.pt") |
| 83 | + >>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024 |
| 84 | + >>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024 |
| 85 | +
|
| 86 | + >>> print("int4 model size: %.2f MB" % int4_model_size_mb) |
| 87 | + int4 model size: 1.25 MB |
| 88 | +
|
| 89 | + >>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb) |
| 90 | + bfloat16 model size: 4.00 MB |
| 91 | +
|
| 92 | +Next, we demonstrate that not only is the quantized model smaller, |
| 93 | +it is also much faster! |
| 94 | + |
| 95 | +.. code:: py |
| 96 | +
|
| 97 | + from torchao.utils import ( |
| 98 | + TORCH_VERSION_AT_LEAST_2_5, |
| 99 | + benchmark_model, |
| 100 | + unwrap_tensor_subclass, |
| 101 | + ) |
| 102 | + |
| 103 | + # Temporary workaround for tensor subclass + torch.compile |
| 104 | + # Only needed for torch version < 2.5 |
| 105 | + if not TORCH_VERSION_AT_LEAST_2_5: |
| 106 | + unwrap_tensor_subclass(model) |
| 107 | + |
| 108 | + num_runs = 100 |
| 109 | + torch._dynamo.reset() |
| 110 | + example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) |
| 111 | + bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) |
| 112 | + int4_time = benchmark_model(model, num_runs, example_inputs) |
| 113 | + |
| 114 | + print("bf16 mean time: %0.3f ms" % bf16_time) |
| 115 | + print("int4 mean time: %0.3f ms" % int4_time) |
| 116 | + print("speedup: %0.1fx" % (bf16_time / int4_time)) |
| 117 | +
|
| 118 | +On a single A100 GPU with 80GB memory, this prints:: |
| 119 | + |
| 120 | + bf16 mean time: 30.393 ms |
| 121 | + int4 mean time: 4.410 ms |
| 122 | + speedup: 6.9x |
| 123 | + |
| 124 | + |
| 125 | +Next Steps |
| 126 | +========== |
| 127 | + |
| 128 | +In this quick start guide, we learned how to quantize a simple model with |
| 129 | +torchao. To learn more about the different workflows supported in torchao, |
| 130 | +see our main `README <https://github.com/pytorch/ao/blob/main/README.md>`__. |
| 131 | +For a more detailed overview of quantization in torchao, visit |
| 132 | +`this page <quantization.html>`__. |
| 133 | + |
| 134 | +Finally, if you would like to contribute to torchao, don't forget to check |
| 135 | +out our `contributor guide <contributor_guide.html>`__ and our list of |
| 136 | +`good first issues <https://github.com/pytorch/ao/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22>`__ on Github! |
0 commit comments