Commit 24114ce

Update
[ghstack-poisoned]
1 parent 32d9b0b commit 24114ce

7 files changed: +249 -114 lines changed

Diff for: test/quantization/test_qat.py

+1 -1

@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:


Diff for: test/quantization/test_quant_api.py

+26

@@ -40,6 +40,7 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
+from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,31 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)

+    def test_int4_weight_only_numerics(self):
+        """
+        Simple test of e2e int4_weight_only workflow, comparing numerics
+        to a bfloat16 baseline.
+        """
+        # TODO(before land) skip on cpu-only
+        # TODO(before land) support other inference techniques?
+
+        # set up inputs
+        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # is that expected?
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_int4_wo = copy.deepcopy(m_ref)
+
+        # quantize
+        quantize_(m_int4_wo, int4_weight_only())
+
+        with torch.no_grad():
+            y_ref = m_ref(x)
+            y_int4_wo = m_int4_wo(x)
+
+        sqnr = compute_error(y_ref, y_int4_wo)
+        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+

 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
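
The new test gates on SQNR (signal-to-quantized-noise ratio) between the bfloat16 baseline and the int4 weight-only output, via compute_error from torchao.quantization.utils. As a rough sketch of what that threshold measures (illustrative only; torchao's own compute_error may differ in detail), SQNR in dB can be computed from the two outputs like this:

import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    # Illustrative helper, not the torchao implementation:
    # SQNR = 20 * log10(||reference|| / ||reference - candidate||), in dB.
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm((reference - candidate).float())
    return (20 * torch.log10(signal / noise)).item()

A threshold of 20 dB, as asserted above, allows the quantization error norm to be at most one tenth of the signal norm.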

Diff for: torchao/core/__init__.py

Whitespace-only changes.

Diff for: torchao/core/config.py

+13

@@ -0,0 +1,13 @@
+import abc
+
+
+# directory location for this might need more polish
+class AOBaseWorkflowConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    what to do with it.
+
+    TODO write a better docblock.
+    """
+
+    pass
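
AOBaseWorkflowConfig is the marker base class that the rest of this stack dispatches on: `quantize_` recognizes any config that inherits from it. A minimal sketch of what a workflow config might look like (the class name and field below are hypothetical, not part of this commit):

from dataclasses import dataclass

from torchao.core.config import AOBaseWorkflowConfig

# Hypothetical workflow config: a plain dataclass carrying the knobs for one
# quantization workflow, marked by inheriting from AOBaseWorkflowConfig.
@dataclass
class MyInt8WeightOnlyConfig(AOBaseWorkflowConfig):
    group_size: int = 128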

Diff for: torchao/quantization/_transform_module.py

+17

@@ -0,0 +1,17 @@
+from typing import Callable, Dict
+
+import torch
+
+from torchao.core.config import AOBaseWorkflowConfig
+
+_QUANTIZE_CONFIG_HANDLER: Dict[
+    AOBaseWorkflowConfig,
+    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
+] = {}
+
+
+def register_quantize_module_handler(config_type):
+    def decorator(func):
+        _QUANTIZE_CONFIG_HANDLER[config_type] = func
+
+    return decorator
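
register_quantize_module_handler associates a config type with a module-level transform by recording it in _QUANTIZE_CONFIG_HANDLER, so `quantize_` can look up the handler for type(config) and apply it to each module. Note that, as committed here, the inner decorator does not return func, so the decorated name itself binds to None; dispatch goes through the registry dict. A minimal sketch of the intended registration and dispatch (the config and handler names below are hypothetical):

import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import (
    _QUANTIZE_CONFIG_HANDLER,
    register_quantize_module_handler,
)

# Hypothetical workflow config, for illustration only.
class MyWorkflowConfig(AOBaseWorkflowConfig):
    pass

# Registering a handler: the decorator records the function keyed by the
# config type in _QUANTIZE_CONFIG_HANDLER.
@register_quantize_module_handler(MyWorkflowConfig)
def _my_workflow_transform(
    module: torch.nn.Module, config: MyWorkflowConfig
) -> torch.nn.Module:
    # A real handler would swap or rewrite the module here; this one is a no-op.
    return module

# Dispatch, roughly as quantize_ would do it: look the handler up by the
# config's type and apply it to a module.
config = MyWorkflowConfig()
handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
quantized = handler(torch.nn.Linear(16, 16), config)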

Diff for: torchao/quantization/qat/api.py

+68 -46

@@ -5,10 +5,14 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch

+from torchao.core.config import AOBaseWorkflowConfig
+from torchao.quantization._transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
         super().__setattr__(name, value)


-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.

     Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
             raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
+                "Activation fake quantization is not supported for embedding"
             )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

-    return _insert_fake_quantize

-
-def from_intx_quantization_aware_training() -> Callable:
+class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Object that knows how to convert a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
     )
     """

-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig

-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod

-    return _remove_fake_quantize
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod


 class ComposableQATQuantizer(TwoStepQuantizer):
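
With this change the QAT prepare and convert entry points become config objects, and the old function names are kept as aliases to the config classes for backward compatibility, so existing quantize_ call sites keep working. A minimal usage sketch, under the assumption that these classes are importable from torchao.quantization.qat and that quantize_ in this stack dispatches on AOBaseWorkflowConfig instances:

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))

# fake-quantize settings: int8 per-token activations, int4 grouped weights
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# prepare: swap nn.Linear / nn.Embedding for their fake-quantized versions
quantize_(model, IntXQuantizationAwareTrainingConfig(activation_config, weight_config))

# ... QAT fine-tuning happens here ...

# convert: swap the fake-quantized modules back to the original module types
quantize_(model, FromIntXQuantizationAwareTrainingConfig())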
