
Commit 2307f5b

[wip] configs configs configs!
Summary:

POC for:

* decoupling configuration from transformation
* stop passing obscure stateful callables around
* enable printing of configuration
* reduce amount of context switching to navigate the logic from `quantize_` to quantizing a single module

TODO more polish before wider discussion.

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 73e9a5c3bf03e2cb645cc0ea43bec162a5f4897e
ghstack-comment-id: 2607756510
Pull Request resolved: #1595
1 parent 32d9b0b commit 2307f5b
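
In call-site terms, the direction described in the summary is that `quantize_` stops receiving an opaque, stateful callable and instead receives a plain config object that it dispatches on. A rough, hedged sketch of the intended user-facing shape follows; the `quant_api.py` side of the change is not shown in this view, so treat the printing behavior of `int4_weight_only()` here as an assumption:

```python
# Hedged sketch only: assumes a torchao build with this stack applied, a CUDA
# device, and that int4_weight_only() returns a config object after the change.
import torch

from torchao.quantization import int4_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()

config = int4_weight_only()
print(config)             # goal from the summary: configs become printable
quantize_(model, config)  # quantize_ dispatches on the config's type
```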

File tree

8 files changed (+256, -143 lines)
test/dtypes/test_affine_quantized.py

+6 -1

@@ -8,13 +8,15 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseWorkflowConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
@@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseWorkflowConfig):
+                quantize_(linear, apply_quant)
+            else:
+                ql = apply_quant(linear)
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {

test/quantization/test_qat.py

+1 -1

@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:

test/quantization/test_quant_api.py

+25

@@ -40,6 +40,7 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
+from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,30 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_int4_weight_only_numerics(self):
+        """
+        Simple test of e2e int4_weight_only workflow, comparing numerics
+        to a bfloat16 baseline.
+        """
+        # set up inputs
+        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # is that expected?
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_int4_wo = copy.deepcopy(m_ref)
+
+        # quantize
+        quantize_(m_int4_wo, int4_weight_only())
+
+        with torch.no_grad():
+            y_ref = m_ref(x)
+            y_int4_wo = m_int4_wo(x)
+
+        sqnr = compute_error(y_ref, y_int4_wo)
+        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
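
The new test gates on the SQNR between the bfloat16 baseline and the int4 weight-only output, computed by the newly imported `compute_error`. For orientation, an SQNR helper that is equivalent in spirit (a sketch; torchao's actual implementation may differ in details) looks like this:

```python
import torch


def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in decibels:
    # 20 * log10(||reference|| / ||reference - candidate||)
    noise = reference - candidate
    return 20 * torch.log10(torch.linalg.norm(reference) / torch.linalg.norm(noise))
```

Under that definition, the `sqnr >= 20` assertion allows the quantized output to deviate from the baseline by roughly 10% in norm.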

torchao/core/__init__.py

Whitespace-only changes.

torchao/core/config.py

+13

@@ -0,0 +1,13 @@
+import abc
+
+
+# directory location for this might need more polish
+class AOBaseWorkflowConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    what to do with it.
+
+    TODO write a better docblock.
+    """
+
+    pass

torchao/quantization/_transform_module.py

+17

@@ -0,0 +1,17 @@
+from typing import Callable, Dict
+
+import torch
+
+from torchao.core.config import AOBaseWorkflowConfig
+
+_QUANTIZE_CONFIG_HANDLER: Dict[
+    AOBaseWorkflowConfig,
+    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
+] = {}
+
+
+def register_quantize_module_handler(config_type):
+    def decorator(func):
+        _QUANTIZE_CONFIG_HANDLER[config_type] = func
+
+    return decorator
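
Together with `torchao/core/config.py`, this file gives `quantize_` a config-type-to-handler lookup table. The dispatch inside `quantize_` itself lives in a file not shown in this view, so the following self-contained toy (all names are illustrative, not torchao's API) shows how such a registry is typically consumed:

```python
# Toy sketch of the registry pattern above; DemoBaseConfig, register_handler,
# apply_config, and Int8WeightOnlyDemoConfig are illustrative names only.
from dataclasses import dataclass
from typing import Callable, Dict, Type

import torch


class DemoBaseConfig:
    """Marker base class: configs carry data, handlers do the transform."""


_HANDLERS: Dict[
    Type[DemoBaseConfig],
    Callable[[torch.nn.Module, DemoBaseConfig], torch.nn.Module],
] = {}


def register_handler(config_type):
    def decorator(func):
        _HANDLERS[config_type] = func
        return func  # returning func keeps the decorated name usable
    return decorator


@dataclass
class Int8WeightOnlyDemoConfig(DemoBaseConfig):
    group_size: int = 32


@register_handler(Int8WeightOnlyDemoConfig)
def _int8_demo_transform(
    module: torch.nn.Module, config: Int8WeightOnlyDemoConfig
) -> torch.nn.Module:
    # A real handler would swap in a quantized module; here we just tag it.
    module.demo_quantized = True
    return module


def apply_config(model: torch.nn.Module, config: DemoBaseConfig) -> None:
    # Look up the handler by config type and apply it to each child module.
    handler = _HANDLERS[type(config)]
    for name, child in model.named_children():
        setattr(model, name, handler(child, config))


model = torch.nn.Sequential(torch.nn.Linear(8, 8))
apply_config(model, Int8WeightOnlyDemoConfig(group_size=64))
print(model[0].demo_quantized)  # True
```

One detail visible in the diff above: the inner `decorator` does not return `func`, so the decorated transform name (e.g. `_intx_quantization_aware_training_transform` below) is rebound to `None` at import time; the handler still lands in the table, but the toy returns `func` to keep the decorated name usable.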

torchao/quantization/qat/api.py

+68 -46

@@ -5,10 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.core.config import AOBaseWorkflowConfig
+from torchao.quantization._transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)
 
 
-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
     Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
-            raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
-            )
-
-    return _insert_fake_quantize
-
-
-def from_intx_quantization_aware_training() -> Callable:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
+            raise ValueError(
+                "Activation fake quantization is not supported for embedding"
+            )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
+
+
+class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Object that knows how to convert a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
     )
     """
 
-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod
-
-    return _remove_fake_quantize
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
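
End to end, the prepare/convert flow these configs support looks roughly like the following, based on the example the docstrings in this file already describe; the exact `FakeQuantizeConfig` arguments and import paths are assumptions that may vary across torchao versions:

```python
# Hedged usage sketch of the prepare/convert flow; argument values are
# illustrative and may not match your torchao version.
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64))

# prepare: swap Linear/Embedding for their fake-quantized counterparts
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... QAT fine-tuning loop goes here ...

# convert: restore the original module types, without fake quantization
quantize_(model, from_intx_quantization_aware_training())
```

Because the old names are kept as aliases of the new config classes ("# for BC"), call sites like this keep working unchanged: calling the class now constructs a config instance instead of returning a closure.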
