
Commit 4e4f4df

Add quick start guide for first time users (#1611)
Documentation in torchao has so far been fairly low-level and geared towards developers. This commit adds a basic quick start guide to help first-time users get familiar with our main quantization flow.
1 parent 0fae693 commit 4e4f4df

9 files changed: +213 -25 lines

.gitignore

+1 -1
@@ -262,7 +262,7 @@ docs/dev
 docs/build
 docs/source/tutorials/*
 docs/source/gen_modules/*
-docs/source/sg_execution_times
+docs/source/sg_execution_times.rst
 
 # LevelDB files
 *.sst

docs/source/contributor_guide.rst

+1 -1
@@ -1,4 +1,4 @@
-torchao Contributor Guide
+Contributor Guide
 -------------------------
 
 .. toctree::

docs/source/getting-started.rst

-4
This file was deleted.

docs/source/index.rst

+8 -9
@@ -1,26 +1,25 @@
 Welcome to the torchao Documentation
-=======================================
+====================================
 
-`torchao <https://github.com/pytorch/ao>`__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on:
-
-1. Getting Started
-2. Developer Notes
-3. API Reference
-4. Tutorials
+`torchao <https://github.com/pytorch/ao>`__ is a library for custom data types and optimizations.
+Quantize and sparsify weights, gradients, optimizers, and activations for inference and training
+using native PyTorch. Please check out the torchao `README <https://github.com/pytorch/ao#torchao-pytorch-architecture-optimization>`__
+for an overall introduction to the library and recent highlights and updates.
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Getting Started
 
-   getting-started
-   sparsity
+   quick_start
 
 .. toctree::
    :glob:
    :maxdepth: 1
    :caption: Developer Notes
 
+   quantization
+   sparsity
    contributor_guide
 
 .. toctree::

docs/source/overview.rst

-4
This file was deleted.

docs/source/quantization.rst

+3 -3
@@ -1,4 +1,4 @@
-Quantization
-============
+Quantization Overview
+---------------------
 
-TBA
+Coming soon!

docs/source/quick_start.rst

+136
@@ -0,0 +1,136 @@
Quick Start Guide
-----------------

In this quick start guide, we will explore how to perform basic quantization using torchao.
First, install the latest stable torchao release::

   pip install torchao

If you prefer to use the nightly release, you can install torchao using the following
command instead::

   pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121

torchao is compatible with the latest 3 major versions of PyTorch, which you will also
need to install (`detailed instructions <https://pytorch.org/get-started/locally/>`__)::

   pip install torch

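Before moving on, it is worth a quick sanity check that both packages import cleanly and that a CUDA GPU is visible, since the rest of this guide runs the model on the GPU in bfloat16. A minimal, illustrative check might look like:

.. code:: py

   import torch
   import torchao

   # The examples below assume a CUDA-capable GPU with bfloat16 support
   print(torch.__version__)
   print(torch.cuda.is_available())
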
First Quantization Example
==========================

The main entry point for quantization in torchao is the `quantize_ <https://pytorch.org/ao/stable/generated/torchao.quantization.quantize_.html#torchao.quantization.quantize_>`__ API.
This function mutates your model in place to insert the custom quantization logic based
on what the user configures. All code in this guide can be found in this `example script <https://github.com/pytorch/ao/blob/main/scripts/quick_start.py>`__.
First, let's set up our toy model:

.. code:: py

   import copy
   import torch

   class ToyLinearModel(torch.nn.Module):
       def __init__(self, m: int, n: int, k: int):
           super().__init__()
           self.linear1 = torch.nn.Linear(m, n, bias=False)
           self.linear2 = torch.nn.Linear(n, k, bias=False)

       def forward(self, x):
           x = self.linear1(x)
           x = self.linear2(x)
           return x

   model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

   # Optional: compile model for faster inference and generation
   model = torch.compile(model, mode="max-autotune", fullgraph=True)
   model_bf16 = copy.deepcopy(model)

Now we call our main quantization API to quantize the linear weights
in the model to int4 in place. More specifically, this applies uint4
weight-only asymmetric per-group quantization, leveraging the
`tinygemm int4mm CUDA kernel <https://github.com/pytorch/pytorch/blob/a8d6afb511a69687bbb2b7e88a3cf67917e1697e/aten/src/ATen/native/cuda/int4mm.cu#L1097>`__
for efficient mixed-dtype matrix multiplication:

.. code:: py

   # torch 2.4+ only
   from torchao.quantization import int4_weight_only, quantize_
   quantize_(model, int4_weight_only(group_size=32))

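`int4_weight_only` is just one of the configurations `quantize_` accepts. As an illustrative aside, and assuming your installed torchao version also exposes `int8_weight_only`, int8 weight-only quantization can be applied to a fresh copy of the model in exactly the same way:

.. code:: py

   # Illustrative sketch: the same API with a different config
   # (assumes int8_weight_only is available in your torchao version)
   from torchao.quantization import int8_weight_only

   model_int8 = copy.deepcopy(model_bf16)
   quantize_(model_int8, int8_weight_only())
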
The quantized model is now ready to use! Note that the quantization
logic is inserted through tensor subclasses, so there is no change
to the overall model structure; only the weight tensors are updated,
but `nn.Linear` modules stay as `nn.Linear` modules:

.. code:: py

   >>> model.linear1
   Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15))

   >>> model.linear2
   Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15))

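Because the weights are tensor subclasses, the usual type checks still behave as you would expect. For example, assuming `AffineQuantizedTensor` is importable from `torchao.dtypes` in your version:

.. code:: py

   >>> from torchao.dtypes import AffineQuantizedTensor
   >>> isinstance(model.linear1.weight, AffineQuantizedTensor)
   True
   >>> isinstance(model.linear1.weight, torch.Tensor)  # and still a torch.Tensor
   True
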
First, verify that the int4 quantized model is roughly a quarter of
the size of the original bfloat16 model:

.. code:: py

   >>> import os
   >>> torch.save(model, "/tmp/int4_model.pt")
   >>> torch.save(model_bf16, "/tmp/bfloat16_model.pt")
   >>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024
   >>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024

   >>> print("int4 model size: %.2f MB" % int4_model_size_mb)
   int4 model size: 1.25 MB

   >>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb)
   bfloat16 model size: 4.00 MB

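These numbers match a back-of-the-envelope estimate: the two 1024 x 1024 bfloat16 weights take 2 bytes per element, while int4 takes half a byte per element plus per-group quantization parameters. The sketch below assumes each group of 32 elements stores one bfloat16 scale and one bfloat16 zero point (4 bytes per group):

.. code:: py

   num_params = 2 * 1024 * 1024                     # two 1024 x 1024 linear weights
   bf16_mb = num_params * 2 / 1024 / 1024           # 2 bytes per element -> 4.00 MB
   int4_mb = num_params * 0.5 / 1024 / 1024         # 4 bits per element  -> 1.00 MB
   # assumed overhead: bf16 scale + bf16 zero point per group of 32 elements
   group_mb = num_params // 32 * 4 / 1024 / 1024    # -> 0.25 MB
   print(bf16_mb, int4_mb + group_mb)               # 4.0 1.25
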
Next, we demonstrate that not only is the quantized model smaller,
it is also much faster!

.. code:: py

   from torchao.utils import (
       TORCH_VERSION_AT_LEAST_2_5,
       benchmark_model,
       unwrap_tensor_subclass,
   )

   # Temporary workaround for tensor subclass + torch.compile
   # Only needed for torch version < 2.5
   if not TORCH_VERSION_AT_LEAST_2_5:
       unwrap_tensor_subclass(model)

   num_runs = 100
   torch._dynamo.reset()
   example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
   bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
   int4_time = benchmark_model(model, num_runs, example_inputs)

   print("bf16 mean time: %0.3f ms" % bf16_time)
   print("int4 mean time: %0.3f ms" % int4_time)
   print("speedup: %0.1fx" % (bf16_time / int4_time))

On a single A100 GPU with 80GB memory, this prints::

   bf16 mean time: 30.393 ms
   int4 mean time: 4.410 ms
   speedup: 6.9x

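Speed and size are only part of the picture; it is also worth spot-checking that the quantized model's outputs stay close to the bfloat16 baseline. A rough, illustrative sanity check (using plain cosine similarity rather than any torchao-specific metric) could look like:

.. code:: py

   with torch.no_grad():
       bf16_out = model_bf16(*example_inputs)
       int4_out = model(*example_inputs)

   # A cosine similarity close to 1.0 means the int4 outputs track the bf16 outputs
   cos = torch.nn.functional.cosine_similarity(
       bf16_out.flatten().float(), int4_out.flatten().float(), dim=0
   )
   print("output cosine similarity: %.4f" % cos.item())
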
Next Steps
==========

In this quick start guide, we learned how to quantize a simple model with
torchao. To learn more about the different workflows supported in torchao,
see our main `README <https://github.com/pytorch/ao/blob/main/README.md>`__.
For a more detailed overview of quantization in torchao, visit
`this page <quantization.html>`__.

Finally, if you would like to contribute to torchao, don't forget to check
out our `contributor guide <contributor_guide.html>`__ and our list of
`good first issues <https://github.com/pytorch/ao/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22>`__ on GitHub!

docs/source/sparsity.rst

+3 -3
@@ -1,5 +1,5 @@
-Sparsity
---------
+Sparsity Overview
+-----------------
 
 Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1).
 
@@ -38,7 +38,7 @@ Given a target sparsity pattern, pruning/sparsifying a model can then be thought
 
 
 * **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model?
-* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead?
+* **Performance** - How can I accelerate my sparse weights for inference and reduce memory overhead?
 
 Our workflow is designed to consist of two parts that answer each question independently:
 
scripts/quick_start.py

+61
@@ -0,0 +1,61 @@
import copy

import torch

from torchao.quantization import int4_weight_only, quantize_
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    benchmark_model,
    unwrap_tensor_subclass,
)

# ================
# | Set up model |
# ================


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")

# Optional: compile model for faster inference and generation
model = torch.compile(model, mode="max-autotune", fullgraph=True)
model_bf16 = copy.deepcopy(model)


# ========================
# | torchao quantization |
# ========================

# torch 2.4+ only
quantize_(model, int4_weight_only(group_size=32))


# =============
# | Benchmark |
# =============

# Temporary workaround for tensor subclass + torch.compile
# Only needed for torch version < 2.5
if not TORCH_VERSION_AT_LEAST_2_5:
    unwrap_tensor_subclass(model)

num_runs = 100
torch._dynamo.reset()
example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),)
bf16_time = benchmark_model(model_bf16, num_runs, example_inputs)
int4_time = benchmark_model(model, num_runs, example_inputs)

print("bf16 mean time: %0.3f ms" % bf16_time)
print("int4 mean time: %0.3f ms" % int4_time)
print("speedup: %0.1fx" % (bf16_time / int4_time))

0 commit comments
