Commit 1c9e446

[wip] configs configs configs!
Summary:

POC for:
* decoupling configuration from transformation
* stop passing obscure stateful callables around
* enable printing of configuration
* reduce amount of context switching to navigate the logic from `quantize_` to quantizing a single module

TODO: more polish before wider discussion.

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 5f5330c5b9c1bdb5df12f3efebd559a42927984c
ghstack-comment-id: 2607756510
Pull Request resolved: #1595
1 parent 32a51ec commit 1c9e446
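For context, a minimal sketch of the end-to-end flow this stack moves toward, mirroring the new `test_int4_weight_only_numerics` test added below (shapes and the CUDA/bfloat16 setup are taken from that test; whether `int4_weight_only(...)` already returns an `AOBaseConfig` at this point in the stack is not shown here):

```python
import copy

import torch

from torchao.quantization import int4_weight_only, quantize_

# toy model, mirroring the new numerics test in this commit
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
model_q = copy.deepcopy(model)

# quantize_ takes the model plus a workflow configuration; the goal of this
# stack is for configs to be plain, printable objects (AOBaseConfig
# subclasses) rather than opaque stateful callables, with the per-module
# transform looked up via register_quantize_module_handler
quantize_(model_q, int4_weight_only(group_size=32))

x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y_ref, y_q = model(x), model_q(x)
```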

File tree

10 files changed: +328 -152 lines

test/dtypes/test_affine_quantized.py (+21 -4)

```diff
@@ -8,6 +8,7 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
@@ -16,6 +17,7 @@
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
@@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self):
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(linear)
+        quantize_(linear, apply_int4_weight_only_quant)
+        ql = linear
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
@@ -102,7 +105,12 @@ def test_tensor_core_layout_transpose(self):
     )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            # TODO(#1690): delete this once config migration is done
+            ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
@@ -181,7 +189,12 @@ def apply_uint6_weight_only_quant(linear):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            # TODO(#1690): delete this once config migration is done
+            ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
 
@@ -195,7 +208,11 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseConfig):
+                quantize_(linear, apply_quant)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                ql = apply_quant(linear)
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {
```
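The branching these tests repeat during the migration can be summarized as a small helper; this is a hypothetical sketch (`_apply` is not part of this PR, the tests inline the logic instead):

```python
import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_


def _apply(linear: torch.nn.Linear, apply_quant) -> torch.nn.Module:
    # hypothetical helper mirroring the test branches above: new-style configs
    # go through quantize_, which mutates the module in place; legacy callables
    # return a new quantized module
    if isinstance(apply_quant, AOBaseConfig):
        quantize_(linear, apply_quant)
        return linear
    else:
        # TODO(#1690): delete this once config migration is done
        return apply_quant(linear)
```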

test/hqq/test_hqq_affine.py (+4 -3)

```diff
@@ -6,6 +6,7 @@
     MappingType,
     ZeroPointDomain,
     int4_weight_only,
+    quantize_,
     uintx_weight_only,
 )
 from torchao.utils import (
@@ -51,9 +52,9 @@ def _eval_hqq(dtype):
     )
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
-        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
-            dummy_linear
-        ).weight
+        config = int4_weight_only(group_size=max(block_size), use_hqq=True)
+        quantize_(dummy_linear, config)
+        q_tensor_hqq = dummy_linear.weight
     else:
         q_tensor_hqq = uintx_weight_only(
             dtype, group_size=max(block_size), use_hqq=True
```

test/quantization/test_qat.py (+1 -1)

```diff
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:
 
```
test/quantization/test_quant_api.py (+25 -0)

```diff
@@ -40,6 +40,7 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
+from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -783,6 +784,30 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int4_weight_only_numerics(self):
+        """
+        Simple test of e2e int4_weight_only workflow, comparing numerics
+        to a bfloat16 baseline.
+        """
+        # set up inputs
+        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # is that expected?
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_int4_wo = copy.deepcopy(m_ref)
+
+        # quantize
+        quantize_(m_int4_wo, int4_weight_only())
+
+        with torch.no_grad():
+            y_ref = m_ref(x)
+            y_int4_wo = m_int4_wo(x)
+
+        sqnr = compute_error(y_ref, y_int4_wo)
+        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
```
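The new test gates on SQNR in dB via `torchao.quantization.utils.compute_error`. As a rough standalone reference (assuming the standard signal-to-quantization-noise definition; the name `sqnr_db` is made up, the real test uses `compute_error`):

```python
import torch


def sqnr_db(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||ref|| / ||ref - other||); higher is better,
    # assumed to match what compute_error reports
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - other))
```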

torchao/core/__init__.py

Whitespace-only changes.

torchao/core/config.py (+29 -0)

```diff
@@ -0,0 +1,29 @@
+import abc
+
+
+class AOBaseConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    how to a apply it to a model. For example::
+
+        # user facing code
+        class WorkflowFooConfig(AOBaseConfig): ...
+            # configuration for workflow `Foo` is defined here
+            bar = 'baz'
+
+        # non user facing code
+        @register_quantize_module_handler(WorkflowFooConfig)
+        def _transform(
+            mod: torch.nn.Module,
+            config: WorkflowFooConfig,
+        ) -> torch.nn.Module:
+            # the transform is implemented here, usually a tensor sublass
+            # weight swap or a module swap
+            ...
+
+        # then, the user calls `quantize_` with a config, and `_transform` is called
+        # under the hood by `quantize_.
+
+    """
+
+    pass
```
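A runnable version of the pattern the docstring describes, using the docstring's made-up `WorkflowFooConfig` name; the handler body is a no-op placeholder rather than a real quantization transform:

```python
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_, register_quantize_module_handler


# user facing: configuration is plain, printable data
@dataclass
class WorkflowFooConfig(AOBaseConfig):
    bar: str = "baz"


# non user facing: the per-module transform, registered against the config type
@register_quantize_module_handler(WorkflowFooConfig)
def _workflow_foo_transform(
    module: torch.nn.Module,
    config: WorkflowFooConfig,
) -> torch.nn.Module:
    # placeholder: a real workflow would swap the weight for a tensor subclass
    # or swap the module itself
    return module


model = torch.nn.Sequential(torch.nn.Linear(16, 16))
quantize_(model, WorkflowFooConfig(bar="qux"))  # dispatches to the handler above
```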

torchao/quantization/__init__.py (+4 -0)

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+
 from torchao.kernel import (
     int_scaled_matmul,
     safe_int_mm,
@@ -85,6 +86,7 @@
     swap_linear_with_smooth_fq_linear,
 )
 from .subclass import *  # noqa: F403
+from .transform_module import register_quantize_module_handler
 from .unified import Quantizer, TwoStepQuantizer
 from .utils import (
     compute_error,
@@ -144,6 +146,8 @@
     # operators/kernels
     "safe_int_mm",
     "int_scaled_matmul",
+    # registration of module transforms for quantize_
+    "register_quantize_module_handler",
     # dataclasses and types
     "MappingType",
     "ZeroPointDomain",
```

torchao/quantization/qat/api.py (+70 -48)

```diff
@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -22,6 +23,9 @@
     TorchAODType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.unified import TwoStepQuantizer
 
 
@@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)
 
 
-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
     Example usage::
@@ -261,45 +279,40 @@ def intx_quantization_aware_training(
         )
         quantize_(
             model,
-            intx_quantization_aware_training(activation_config, weight_config),
+            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
         )
 
     Note: If the returned function is applied on a module that is not
     `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
             raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
+                "Activation fake quantization is not supported for embedding"
             )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
 
-    return _insert_fake_quantize
 
-
-def from_intx_quantization_aware_training() -> Callable:
+class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Object that knows how to convert a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -311,26 +324,35 @@ def from_intx_quantization_aware_training() -> Callable:
         from torchao.quantization import quantize_
         quantize_(
             model_with_fake_quantized_linears,
-            from_intx_quantization_aware_training(),
+            FromIntXQuantizationAwareTrainingConfig(),
         )
     """
 
-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
 
-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod
 
-    return _remove_fake_quantize
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
```
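Putting the two new QAT configs together, the prepare/convert flow from the docstrings in this file looks roughly like the sketch below. The `FakeQuantizeConfig` arguments are illustrative, and the direct import from `torchao.quantization.qat.api` is an assumption for this point in the stack:

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.api import (
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# prepare: swap Linear/Embedding for fake-quantized versions
# (illustrative fake-quantization settings)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# ... QAT fine-tuning happens here ...

# convert: swap the fake-quantized modules back to the plain ones
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
```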
