
Commit 7361238

Commit message: temp
1 parent 1240b19 commit 7361238

File tree

4 files changed: +75, -3 lines

docs/source/contributor_guide.rst (+4, -2)
@@ -287,6 +287,7 @@ A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``
 and also dispatch functions for torch functions ``__torch_function__`` and aten ops ``__torch_dispatch__``.
 
 Here is an example of basic structure::
+
     # check out docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437
     from torchao.utils import TorchAOBaseTensor
 
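For readers following along, here is a minimal sketch of the kind of structure that example continues into. This is an illustration based on the surrounding prose, not the commit's code; the field names ``int_data`` and ``scale`` are assumptions::

    # Hedged sketch only: a wrapper tensor subclass with the basic methods
    # named above. `int_data` and `scale` are illustrative field names.
    import torch
    from torchao.utils import TorchAOBaseTensor

    class MyDtypeTensor(TorchAOBaseTensor):
        @staticmethod
        def __new__(cls, int_data, scale, *args, **kwargs):
            kwargs["dtype"] = kwargs.get("dtype", scale.dtype)
            shape = int_data.shape
            return torch.Tensor._make_wrapper_subclass(cls, shape, *args, **kwargs)

        def __init__(self, int_data, scale, *args, **kwargs):
            # store the quantized payload and its quantization parameters
            self.int_data = int_data
            self.scale = scale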
@@ -374,12 +375,12 @@ Operator Support
 ~~~~~~~~~~~~~~~~
 There are two types of operator support: torch functions and aten ops. For torch functions (e.g. ``torch.nn.functional.linear``), we'll need to overwrite the ``__torch_function__`` callback in the Tensor subclass; for aten ops (e.g. ``torch.ops.aten.mm``), we'll need to overwrite the ``__torch_dispatch__`` callback function.
 
-For a new dtype, we’d like people to define the following decorator::
-if your dtype class is inherited from `torchao.utils.TorchAoBaseTensor`, you can do:
+For a new dtype, we’d like people to define the following decorator. If your dtype class is inherited from `torchao.utils.TorchAoBaseTensor`, you can do::
 
     implements = my_dtype_tensor_cls.implements
 
 And we can implement the operator dispatch with the following::
+
     # Example for torch_function dispatch for torch.nn.functional.linear
     def _quantized_linear_op(input_tensor, weight_tensor, bias):
         if isinstance(input_tensor, MyDtypeTensor):
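To make the hunk above concrete, here is a hedged sketch of how the ``implements`` registration typically completes. Only the decorator pattern comes from the diff; the dequantize-and-fall-back body and the ``dequantize()`` method are assumptions for illustration::

    # Assumes MyDtypeTensor exposes `implements` (via TorchAOBaseTensor)
    # and a `dequantize()` method; both are illustrative here.
    import torch

    implements = MyDtypeTensor.implements

    @implements(torch.nn.functional.linear)
    def _(func, types, args, kwargs):
        input_tensor, weight_tensor, bias = (
            args[0],
            args[1],
            args[2] if len(args) > 2 else None,
        )
        # Fallback path: dequantize to a float weight and run the
        # normal linear op.
        weight = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight, bias)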
@@ -426,6 +427,7 @@ What ops do we need to overwrite? This depends on the model we are trying to qua
 ``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default``
 
 You can also find the ops that can be overwritten in ``__torch_function__`` or ``__torch_dispatch__`` with the following code. Start with a model that you want to optimize, overwrite just the important ops like linear, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details)::
+
     class M(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()
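One practical way to discover which aten ops a model actually dispatches is a logging dispatch mode. This is a sketch using the public ``torch.utils._python_dispatch.TorchDispatchMode`` utility, not code from this commit::

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LoggingMode(TorchDispatchMode):
        """Print every aten op dispatched while the mode is active."""
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)
            return func(*args, **(kwargs or {}))

    # Running a linear layer under the mode surfaces candidates like
    # torch.ops.aten.mm.default and torch.ops.aten.t.default.
    with LoggingMode():
        torch.nn.Linear(8, 8)(torch.randn(2, 8))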

docs/source/index.rst (+1, -1)
@@ -102,4 +102,4 @@ Welcome to the torchao Documentation
    :caption: Tutorials
 
    serialization
-
+   subclass_basic

docs/source/subclass_basic.py (+5, new file)
@@ -0,0 +1,5 @@
+Quantization with Tensor Subclasses
+===================================
+
+Coming soon.
+

scripts/module_swap_example.py (+65, new file)
@@ -0,0 +1,65 @@
+from typing import Tuple
+import torch
+
+
+class ToyModel(torch.nn.Module):
+    def __init__(self, m: int, n: int, k: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(m, n, bias=False)
+        self.linear2 = torch.nn.Linear(n, k, bias=False)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+class QuantizedLinear(torch.nn.Linear):
+    """
+    Linear module that performs dynamic and symmetric weight-only
+    int8 quantization.
+    """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w_int8, scale = int8_symmetric_quantize(self.weight)
+        return torch.mm(x, w_int8.t().to(x.dtype)) * scale.t()
+
+    @classmethod
+    def from_float(cls, mod: torch.nn.Linear):
+        new_linear = cls(mod.in_features, mod.out_features, mod.bias is not None)
+        new_linear.weight = mod.weight
+        return new_linear
+
+
+def int8_symmetric_quantize(
+    fp32_tensor: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Symmetrically quantize the torch.float32 tensor into torch.int8.
+    Return a 2-tuple of (quantized value, scale).
+    """
+    quant_min = -128
+    quant_max = 127
+    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
+    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    scale = scale.view(fp32_tensor.shape[0], -1)
+    out = torch.round(fp32_tensor * (1.0 / scale))
+    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
+    return out, scale
+
+
+if __name__ == "__main__":
+    model = ToyModel(64, 128, 32).cuda()
+    example_inputs = torch.randn((1, 64), dtype=torch.float32, device="cuda")
+
+    # Swap torch.nn.Linear with QuantizedLinear
+    for name, child in model.named_children():
+        if type(child) == torch.nn.Linear:
+            new_linear = QuantizedLinear.from_float(child)
+            setattr(model, name, new_linear)
+
+    print("quantized model: ", model)
+    print("output: ", model(example_inputs))
