[wip] configs configs configs!
Summary:

POC for:

* decoupling configuration from transformation (see the sketch after this list)
* no longer passing obscure stateful callables around
* enabling printing of configurations
* reducing the amount of context switching needed to navigate the logic from
  `quantize_` to quantizing a single module
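
A minimal sketch of the intended pattern, using the names introduced in the diffs
below (the `quantize_` dispatch itself lives in a file whose diff is not shown here,
so exactly how it consumes the config is an assumption; `MyWorkflowConfig` and its
transform are hypothetical examples):

```
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import (
    register_quantize_module_handler,
)


# configuration is a plain dataclass: printable, easy to inspect, no hidden state
@dataclass
class MyWorkflowConfig(AOBaseWorkflowConfig):  # hypothetical example config
    group_size: int = 32


# the per-module transformation is a separate function, registered by config type
@register_quantize_module_handler(MyWorkflowConfig)
def _my_workflow_transform(
    module: torch.nn.Module,
    config: MyWorkflowConfig,
) -> torch.nn.Module:
    # replace or patch `module` according to `config` and return it
    return module
```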

TODO more polish before wider discussion.

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_int4_weight_only_numerics
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_standalone
pytest test/quantization/test_qat.py -s -x -k test_quantize_api_convert_path
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Jan 21, 2025
1 parent 32d9b0b commit 997f715
Showing 7 changed files with 246 additions and 113 deletions.
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:
26 changes: 26 additions & 0 deletions test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,31 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

    def test_int4_weight_only_numerics(self):
        """
        Simple test of e2e int4_weight_only workflow, comparing numerics
        to a bfloat16 baseline.
        """
        # TODO(before land) skip on cpu-only
        # TODO(before land) support other inference techniques?

        # set up inputs
        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
        # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
        # is that expected?
        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
        m_int4_wo = copy.deepcopy(m_ref)

        # quantize
        quantize_(m_int4_wo, int4_weight_only())

        with torch.no_grad():
            y_ref = m_ref(x)
            y_int4_wo = m_int4_wo(x)

        sqnr = compute_error(y_ref, y_int4_wo)
        assert sqnr >= 20, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
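
For context on the `sqnr >= 20` assertion in the new test: assuming `compute_error`
implements the usual SQNR-in-dB definition, it is equivalent to something like the
sketch below, so the test requires the int4 output to stay within roughly 20 dB of
the bfloat16 baseline.

```
import torch


def sqnr_db(ref: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in decibels; higher means the
    # quantized output is closer to the reference output
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - result))
```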
Empty file added torchao/core/__init__.py
13 changes: 13 additions & 0 deletions torchao/core/config.py
@@ -0,0 +1,13 @@
import abc


# directory location for this might need more polish
class AOBaseWorkflowConfig(abc.ABC):
"""
If a workflow config inherits from this then `quantize_` knows
what to do with it.
TODO write a better docblock.
"""

pass
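
Because concrete workflow configs inherit from this base class, `quantize_` can
distinguish them from the legacy stateful-callable path with a simple isinstance
check; a hypothetical sketch (the real dispatch code is not among the loaded diffs):

```
from torchao.core.config import AOBaseWorkflowConfig


def _is_workflow_config(obj) -> bool:
    # new-style config object vs. legacy "apply tensor subclass" callable
    return isinstance(obj, AOBaseWorkflowConfig)
```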
15 changes: 15 additions & 0 deletions torchao/quantization/_transform_module.py
@@ -0,0 +1,15 @@
import torch
from typing import Dict, Callable
from torchao.core.config import AOBaseWorkflowConfig

_QUANTIZE_CONFIG_HANDLER: Dict[
    AOBaseWorkflowConfig,
    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
] = {}


def register_quantize_module_handler(config_type):
    def decorator(func):
        _QUANTIZE_CONFIG_HANDLER[config_type] = func

    return decorator
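
A hypothetical sketch of how `quantize_` could consume this registry (the actual
`quant_api.py` changes are not among the loaded diffs): look up the transform by
the exact config type and apply it to each matching module.

```
import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization._transform_module import _QUANTIZE_CONFIG_HANDLER


def _apply_workflow_config(
    module: torch.nn.Module, config: AOBaseWorkflowConfig
) -> torch.nn.Module:
    # handlers are registered per concrete config type
    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
    return handler(module, config)
```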
112 changes: 67 additions & 45 deletions torchao/quantization/qat/api.py
@@ -9,12 +9,16 @@

import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization.granularity import (
Granularity,
PerAxis,
PerGroup,
PerToken,
)
from torchao.quantization._transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.quant_primitives import (
_SUB_BYTE_INT_BOUNDS,
_SUB_BYTE_UINT_BOUNDS,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
super().__setattr__(name, value)


def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> Callable:
@dataclass
class IntXQuantizationAwareTrainingWorkflowConfig(AOBaseWorkflowConfig):
activation_config: Optional[FakeQuantizeConfig] = None
weight_config: Optional[FakeQuantizeConfig] = None


# for BC
intx_quantization_aware_training = IntXQuantizationAwareTrainingWorkflowConfig


@register_quantize_module_handler(IntXQuantizationAwareTrainingWorkflowConfig)
def _intx_quantization_aware_training_transform(
module: torch.nn.Module,
config: IntXQuantizationAwareTrainingWorkflowConfig,
) -> torch.nn.Module:
"""
Return a function that applies fake quantization to a `torch.nn.Module`.
THIS IS NOT A PUBLIC API - any usage of this outside of torchao
can break at any time.
Apply fake quantization to a `torch.nn.Module`.
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
`torch.nn.Embedding` with an activation config, then we will raise
ValueError as these are not supported.
"""

def _insert_fake_quantize(mod: torch.nn.Module):
"""
Swap the given module with its corresponding fake quantized version.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

mod = module
activation_config = config.activation_config
weight_config = config.weight_config

if isinstance(mod, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(
mod,
activation_config,
weight_config,
)
elif isinstance(mod, torch.nn.Embedding):
if activation_config is not None:
raise ValueError(
"Module of type '%s' does not have QAT support" % type(mod)
"Activation fake quantization is not supported for embedding"
)
return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
else:
raise ValueError("Module of type '%s' does not have QAT support" % type(mod))

return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
class FromIntXQuantizationAwareTrainingWorkflowConfig(AOBaseWorkflowConfig):
"""
Return a function that converts a model with fake quantized modules,
Object that knows how to convert a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
back to model with the original, corresponding modules without
@@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear
pass


# for BC
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingWorkflowConfig

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingWorkflowConfig)
def _from_intx_quantization_aware_training_transform(
mod: torch.nn.Module,
config: FromIntXQuantizationAwareTrainingWorkflowConfig,
) -> torch.nn.Module:
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod


class ComposableQATQuantizer(TwoStepQuantizer):
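
A hedged usage sketch for the new QAT config classes above, mirroring the example
in the existing torchao QAT docstrings (import paths and the `model` object are
assumed; the lowercase BC aliases keep the old names working):

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

model = torch.nn.Sequential(torch.nn.Linear(32, 32))

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

# prepare: swap Linear/Embedding with fake-quantized versions
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... train ...

# convert: swap the fake-quantized modules back to the original module types
quantize_(model, from_intx_quantization_aware_training())
```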