Commit fbd4716

Enable range learning for QAT
**Summary:** This commit adds the option for QAT users to use range learning during training. Range learning means we train the scale and zero point instead of recomputing them based on the input at every iteration. Example usage:

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
m = M()
example_inputs = (torch.randn(16, 32),)
quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))

# New required step to turn scales and zero points into trainable
# `nn.Parameters`, must be called before initializing the optimizer
initialize_fake_quantizers(m, example_inputs)

# initialize the optimizer
# do training
```

**Test Plan:**

python test/quantization/test_qat.py -k test_fake_quantize_config_dynamic_and_range_learning
python test/quantization/test_qat.py -k test_fake_quantizer_range_learning
python test/quantization/test_qat.py -k test_qat_range_learning
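
The sketch below stitches the example above into one self-contained script. The toy `TinyModel`, its layer sizes, the input shape, and the SGD hyperparameters are illustrative assumptions, not part of this commit; the torchao calls mirror the example usage and the new tests.

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
    initialize_fake_quantizers,
)


class TinyModel(torch.nn.Module):
    # Hypothetical toy model, stands in for `M` in the example above
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
model = TinyModel()
example_inputs = (torch.randn(16, 32),)
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=config))

# Turn scales and zero points into trainable `nn.Parameter`s;
# must happen before the optimizer is constructed
initialize_fake_quantizers(model, example_inputs)

# One illustrative training step: the learned qparams are now ordinary
# parameters, so the optimizer can update them alongside the weights
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(*example_inputs).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Calling `initialize_fake_quantizers` before constructing the optimizer matters because the scale and zero point tensors only become `nn.Parameter`s at that point; an optimizer created earlier would never see them in `model.parameters()`.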
1 parent 5549da8 commit fbd4716

8 files changed: +217 -19 lines changed

test/quantization/test_qat.py

Lines changed: 117 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 
 import copy
 import unittest
+from typing import List
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +27,9 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -99,6 +102,16 @@ def __init__(self):
     def example_inputs(self):
         return (torch.randn(1, 512).to(torch.float),)
 
+    def _get_all_weight_qparams(self) -> List[torch.Tensor]:
+        return [
+            self.linear1.weight_fake_quantizer.scale,
+            self.linear1.weight_fake_quantizer.zero_point,
+            self.sub.linear.weight_fake_quantizer.scale,
+            self.sub.linear.weight_fake_quantizer.zero_point,
+            self.linear2.weight_fake_quantizer.scale,
+            self.linear2.weight_fake_quantizer.zero_point,
+        ]
+
     def forward(self, x):
         x = self.linear1(x)
         x = self.sub(x)
@@ -996,6 +1009,21 @@ def test_fake_quantize_config_dtype(self):
         FakeQuantizeConfig(TorchAODType.INT7, "per_token")
         FakeQuantizeConfig(torch.int8, "per_token")
 
+    def test_fake_quantize_config_dynamic_and_range_learning(self):
+        """
+        Test that `is_dynamic` and `range_learning` cannot both be set.
+        """
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=True, range_learning=False
+        )
+        FakeQuantizeConfig(
+            torch.int8, "per_channel", is_dynamic=False, range_learning=True
+        )
+        with self.assertRaisesRegex(ValueError, "not compatible"):
+            FakeQuantizeConfig(
+                torch.int8, "per_channel", is_dynamic=True, range_learning=True
+            )
+
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
@@ -1591,6 +1619,95 @@ def test_qat_8da4w_eps(self):
         actual_out = converted_model.linear1(x)
         torch.testing.assert_close(expected_out, actual_out, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_fake_quantizer_range_learning(self):
+        """
+        Test that range learning requires `FakeQuantizer`s to be initialized correctly.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        fake_quantizer = FakeQuantizer(config)
+        example_inputs = (torch.randn(2, 3),)
+
+        # Not initialized, should fail
+        self.assertFalse(fake_quantizer._initialized)
+        self.assertIsNone(fake_quantizer.scale)
+        self.assertIsNone(fake_quantizer.zero_point)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            fake_quantizer(*example_inputs)
+
+        # Should pass after initializing
+        initialize_fake_quantizers(fake_quantizer, example_inputs)
+        self.assertTrue(fake_quantizer._initialized)
+        self.assertIsInstance(fake_quantizer.scale, torch.nn.Parameter)
+        self.assertIsInstance(fake_quantizer.zero_point, torch.nn.Parameter)
+        self.assertTrue(fake_quantizer.scale.requires_grad)
+        self.assertTrue(fake_quantizer.zero_point.requires_grad)
+        fake_quantizer(*example_inputs)
+
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_qat_range_learning(self):
+        """
+        Test end-to-end QAT flow with range learning.
+        """
+        config = FakeQuantizeConfig(
+            torch.int8,
+            "per_channel",
+            is_dynamic=False,
+            range_learning=True,
+            scale_precision=torch.float32,
+            zero_point_precision=torch.float32,
+        )
+        m = M()
+        example_inputs = m.example_inputs()
+        quantize_(m, IntXQuantizationAwareTrainingConfig(weight_config=config))
+
+        # Not initialized, should fail
+        for t in m._get_all_weight_qparams():
+            self.assertIsNone(t)
+        with self.assertRaisesRegex(
+            ValueError,
+            "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+            "before initializing the optimizer and beginning training.",
+        ):
+            m(*example_inputs)
+
+        # Should pass after initializing
+        # All scales and zero points should be in `m.parameters()`
+        initialize_fake_quantizers(m, example_inputs)
+        params = set(m.parameters())
+        for t in m._get_all_weight_qparams():
+            self.assertIsInstance(t, torch.nn.Parameter)
+            self.assertTrue(t.requires_grad)
+            self.assertTrue(t in params)
+        m(*example_inputs)
+
+        # Simulate training
+        optimizer = torch.optim.SGD(
+            m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
+        )
+        loss_fn = torch.nn.CrossEntropyLoss()
+        target = torch.randn(1, 512).float()
+        out = m(*example_inputs)
+        loss = loss_fn(out, target)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
 
 if __name__ == "__main__":
     unittest.main()
```

third_party/cutlass

Submodule cutlass updated 507 files

torchao/quantization/qat/__init__.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -4,6 +4,7 @@
     FromIntXQuantizationAwareTrainingConfig,
     IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
+    initialize_fake_quantizers,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -17,11 +18,12 @@
 __all__ = [
     "ComposableQATQuantizer",
     "FakeQuantizeConfig",
-    "Int4WeightOnlyQATQuantizer",
+    "FromIntXQuantizationAwareTrainingConfig",
     "Int4WeightOnlyEmbeddingQATQuantizer",
+    "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "IntXQuantizationAwareTrainingConfig",
+    "initialize_fake_quantizers",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "IntXQuantizationAwareTrainingConfig",
 ]
```

torchao/quantization/qat/api.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -51,7 +51,8 @@ class FakeQuantizeConfig:
         zero_point_precision: zero point dtype (default torch.int32)
         zero_point_domain: whether zero point is in integer (default) or float domain
         is_dynamic: whether to use dynamic (default) or static scale and zero points
-        range_learning: whether to learn scale and zero points during training (coming soon)
+        range_learning: whether to learn scale and zero points during training
+            (default false), not compatible with `is_dynamic`.
 
     kwargs (optional):
         group_size: size of each group in per group fake quantization,
@@ -123,6 +124,10 @@ def __init__(
                 "Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes)
             )
 
+        # Dynamic is not compatible with range learning
+        if is_dynamic and range_learning:
+            raise ValueError("`is_dynamic` is not compatible with `range_learning`")
+
     def _get_granularity(
         self,
         granularity: Union[Granularity, str, None],
@@ -394,3 +399,23 @@ def convert(
         for quantizer in self.quantizers:
             model = quantizer.convert(model)
         return model
+
+
+def initialize_fake_quantizers(
+    model: torch.nn.Module,
+    example_inputs: Tuple[Any, ...],
+) -> None:
+    """
+    Initialize the scales and zero points on all
+    :class:`~torchao.quantization.qat.fake_quantizer.FakeQuantizer`
+    in the model based on the provided example inputs.
+    """
+    # avoid circular dependencies
+    from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+
+    def _set_initialized(m: torch.nn.Module):
+        if isinstance(m, FakeQuantizer):
+            m._initialized = True
+
+    model.apply(_set_initialized)
+    model(*example_inputs)
```
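
In effect the helper does two things: it flags every `FakeQuantizer` in the module tree as initialized, then runs a single forward pass with the example inputs so each fake quantizer computes its initial scale and zero point, which `_maybe_update_qparams_for_range_learning` (in `fake_quantizer.py` below) promotes to trainable `nn.Parameter`s. A small sketch mirroring `test_fake_quantizer_range_learning`; the `(2, 3)` input shape is simply the arbitrary shape used in that test:

```python
import torch

from torchao.quantization.qat import FakeQuantizeConfig, initialize_fake_quantizers
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

config = FakeQuantizeConfig(
    torch.int8,
    "per_channel",
    is_dynamic=False,
    range_learning=True,
    scale_precision=torch.float32,
    zero_point_precision=torch.float32,
)
fq = FakeQuantizer(config)
example_inputs = (torch.randn(2, 3),)

# Before initialization there are no qparams yet; calling fq(...) here
# would raise the ValueError added in fake_quantizer.py below
assert fq.scale is None and fq.zero_point is None

# The helper's single forward pass computes the qparams and, because
# range_learning=True, wraps them in nn.Parameter
initialize_fake_quantizers(fq, example_inputs)
assert isinstance(fq.scale, torch.nn.Parameter)
assert isinstance(fq.zero_point, torch.nn.Parameter)
```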

torchao/quantization/qat/embedding.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -92,6 +92,7 @@ def to_embedding(self) -> torch.nn.Embedding:
             self.scale_grad_by_freq,
             self.sparse,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -116,6 +117,7 @@ def from_embedding(
             mod.sparse,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
```

torchao/quantization/qat/fake_quantizer.py

Lines changed: 41 additions & 7 deletions
```diff
@@ -31,6 +31,7 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _Round,
 )
 
 
@@ -46,32 +47,43 @@ def __init__(self, config: FakeQuantizeConfig):
         self.scale: Optional[torch.Tensor] = None
         self.zero_point: Optional[torch.Tensor] = None
 
-        # TODO: support range learinng
-        if self.config.range_learning:
-            raise NotImplementedError("Range learning is not supported yet")
+        # For range learning only
+        # TODO: make this configurable?
+        self._scale_eps = 1e-9
+        self._initialized = False
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Apply fake quantization to the tensor based on the bit-width,
         granularity, symmetry, and other properties specified in the config.
         """
         if not self.enabled:
             return x
 
+        if (
+            self.config.range_learning
+            and not self._initialized
+            and (self.scale is None or self.zero_point is None)
+        ):
+            raise ValueError(
+                "Scales and zero points must be initialized for range learning. "
+                "Please call `torchao.quantization.qat.initialize_fake_quantizers` "
+                "before initializing the optimizer and beginning training."
+            )
+
         if isinstance(self.config.granularity, PerToken):
             return self._per_token_forward(x)
         elif isinstance(self.config.granularity, (PerAxis, PerGroup)):
             return self._per_channel_or_group_forward(x)
         else:
             raise ValueError("Unknown granularity '%s'" % self.config.granularity)
 
-    def _per_token_forward(self, x: torch.Tensor):
+    def _per_token_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per token fake quantization on the tensor.
         """
         if self.config.is_symmetric:
             raise NotImplementedError("Symmetric per token is not supported yet")
-
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         if self._should_compute_qparams():
             self.scale, self.zero_point = choose_qparams_affine(
@@ -85,9 +97,10 @@ def _per_token_forward(self, x: torch.Tensor):
                 scale_dtype=self.config.scale_precision,
                 zero_point_dtype=self.config.zero_point_precision,
             )
+            self._maybe_update_qparams_for_range_learning()
         return _fake_quantize_per_token(x, self.scale, self.zero_point, qmin, qmax)
 
-    def _per_channel_or_group_forward(self, x: torch.Tensor):
+    def _per_channel_or_group_forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Perform per channel or per group fake quantization on the tensor.
         We express per channel using per group where the group size is the size
@@ -129,6 +142,7 @@ def _per_channel_or_group_forward(self, x: torch.Tensor):
                 eps=self.config.eps,
             )
             self.zero_point = self.zero_point.to(zero_point_precision)
+            self._maybe_update_qparams_for_range_learning()
 
         qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
         return _fake_quantize_per_channel_group(
@@ -147,6 +161,26 @@ def _should_compute_qparams(self) -> bool:
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
 
+    def _maybe_update_qparams_for_range_learning(self) -> None:
+        """
+        If range learning is enabled, turn scales and zero points into trainable parameters.
+        This function is idempotent and should only be called once.
+        """
+        if (
+            not self.config.range_learning
+            or isinstance(self.scale, torch.nn.Parameter)
+            or isinstance(self.zero_point, torch.nn.Parameter)
+        ):
+            return
+        scale, zero_point = self.scale, self.zero_point
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.config.dtype]
+        # Stabilize range learning
+        scale = torch.clamp(scale, min=self._scale_eps)
+        zero_point = _Round.apply(zero_point)
+        zero_point = torch.clamp(zero_point, qmin, qmax)
+        self.scale = torch.nn.Parameter(scale, requires_grad=True)
+        self.zero_point = torch.nn.Parameter(zero_point, requires_grad=True)
+
     def __repr__(self) -> str:
         """
         Return a human readable representation of this `FakeQuantizer` with config details.
```
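
`_Round` is imported from `.utils` but its implementation is not part of this diff. Given how it is used here, rounding the zero point once before it becomes a trainable parameter, it is presumably a straight-through-estimator style round. The snippet below is a sketch of such a helper under that assumption, not a copy of the torchao utility:

```python
import torch


class _RoundSTE(torch.autograd.Function):
    """Hypothetical stand-in for `_Round`: round to the nearest integer in
    the forward pass, pass gradients through unchanged in the backward pass
    (straight-through estimator)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output


zp = torch.tensor([0.4, -1.6], requires_grad=True)
rounded = _RoundSTE.apply(zp)
rounded.sum().backward()
print(zp.grad)  # tensor([1., 1.]) -- gradients flow through the rounding
```

The `torch.clamp(scale, min=self._scale_eps)` call in `_maybe_update_qparams_for_range_learning` presumably serves a similar stabilizing purpose: it keeps the scale strictly positive at the moment it becomes learnable.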

torchao/quantization/qat/linear.py

Lines changed: 10 additions & 6 deletions
```diff
@@ -18,6 +18,7 @@
     _replace_linear_int4,
     groupwise_affine_quantize_tensor,
 )
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     TorchAODType,
     ZeroPointDomain,
@@ -83,12 +84,13 @@ def __init__(
 
         # initialize weight fake quantizer
         if weight_config is not None:
-            group_size = weight_config.group_size
-            if group_size is not None and in_features % group_size != 0:
-                raise ValueError(
-                    "in_features (%s) %% group_size (%s) must be == 0"
-                    % (in_features, group_size)
-                )
+            if isinstance(weight_config.granularity, PerGroup):
+                group_size = weight_config.group_size
+                if group_size is not None and in_features % group_size != 0:
+                    raise ValueError(
+                        "in_features (%s) %% group_size (%s) must be == 0"
+                        % (in_features, group_size)
+                    )
             self.weight_fake_quantizer = FakeQuantizer(weight_config)
         else:
             self.weight_fake_quantizer = None
@@ -108,6 +110,7 @@ def to_linear(self) -> torch.nn.Linear:
             self.out_features,
             self.bias is not None,
             device=self.weight.device,
+            dtype=self.weight.dtype,
         )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
@@ -131,6 +134,7 @@ def from_linear(
             activation_config=activation_config,
             weight_config=weight_config,
             device=mod.weight.device,
+            dtype=mod.weight.dtype,
        )
         # In distributed training, the model may be instantiated
         # on the meta device, in which case there is no need to
```
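
With this change the `in_features % group_size` check only runs for per-group granularity, so per-channel and per-token weight configs never consult `group_size` in the linear constructor. A brief sketch of a configuration this accommodates; the layer sizes are arbitrary choices for illustration:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)

# Per-channel weight fake quantization has no group size, so the
# divisibility check is skipped entirely
per_channel = FakeQuantizeConfig(torch.int8, "per_channel")
model = torch.nn.Sequential(torch.nn.Linear(30, 7))  # 30 is not divisible by 32
quantize_(model, IntXQuantizationAwareTrainingConfig(weight_config=per_channel))

# Per-group configs still validate divisibility: applying this config to
# the Linear(30, 7) above would raise the ValueError shown in the diff
per_group = FakeQuantizeConfig(torch.int4, group_size=32)
```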
