From 24114cebb3fd77737185b1e30bef050283c51478 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 08:49:11 -0800
Subject: [PATCH 01/10] Update

[ghstack-poisoned]
---
 test/quantization/test_qat.py             |   2 +-
 test/quantization/test_quant_api.py       |  26 +++
 torchao/core/__init__.py                  |   0
 torchao/core/config.py                    |  13 ++
 torchao/quantization/_transform_module.py |  17 ++
 torchao/quantization/qat/api.py           | 114 +++++++------
 torchao/quantization/quant_api.py         | 191 ++++++++++++++--------
 7 files changed, 249 insertions(+), 114 deletions(-)
 create mode 100644 torchao/core/__init__.py
 create mode 100644 torchao/core/config.py
 create mode 100644 torchao/quantization/_transform_module.py

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 8a78b8b387..82324394a8 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:
 
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 177c357047..ca2cbf08ec 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -40,6 +40,7 @@
     Int4WeightOnlyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
 )
+from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
@@ -761,6 +762,31 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    def test_int4_weight_only_numerics(self):
+        """
+        Simple test of e2e int4_weight_only workflow, comparing numerics
+        to a bfloat16 baseline.
+        """
+        # TODO(before land) skip on cpu-only
+        # TODO(before land) support other inference techniques?
+
+        # set up inputs
+        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # is that expected?
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_int4_wo = copy.deepcopy(m_ref)
+
+        # quantize
+        quantize_(m_int4_wo, int4_weight_only())
+
+        with torch.no_grad():
+            y_ref = m_ref(x)
+            y_int4_wo = m_int4_wo(x)
+
+        sqnr = compute_error(y_ref, y_int4_wo)
+        assert sqnr >= 20, f"SQNR {sqnr} is too low"
+
 
 class TestMultiTensorFlow(TestCase):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
diff --git a/torchao/core/__init__.py b/torchao/core/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torchao/core/config.py b/torchao/core/config.py
new file mode 100644
index 0000000000..fbc1216212
--- /dev/null
+++ b/torchao/core/config.py
@@ -0,0 +1,13 @@
+import abc
+
+
+# directory location for this might need more polish
+class AOBaseWorkflowConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    what to do with it.
+
+    TODO write a better docblock.
+    """
+
+    pass
diff --git a/torchao/quantization/_transform_module.py b/torchao/quantization/_transform_module.py
new file mode 100644
index 0000000000..f14e79b5a9
--- /dev/null
+++ b/torchao/quantization/_transform_module.py
@@ -0,0 +1,17 @@
+from typing import Callable, Dict
+
+import torch
+
+from torchao.core.config import AOBaseWorkflowConfig
+
+_QUANTIZE_CONFIG_HANDLER: Dict[
+    AOBaseWorkflowConfig,
+    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
+] = {}
+
+
+def register_quantize_module_handler(config_type):
+    def decorator(func):
+        _QUANTIZE_CONFIG_HANDLER[config_type] = func
+        return func  # keep the decorated function usable under its own name
+    return decorator
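
The registry above is what `quantize_` dispatches on. A minimal sketch of how a workflow is expected to plug into it, mirroring the handlers added to `quant_api.py` and `qat/api.py` later in this patch (the `MyWeightOnlyConfig` name and its no-op handler are purely illustrative):

    from dataclasses import dataclass

    import torch

    from torchao.core.config import AOBaseWorkflowConfig
    from torchao.quantization._transform_module import (
        _QUANTIZE_CONFIG_HANDLER,
        register_quantize_module_handler,
    )


    @dataclass
    class MyWeightOnlyConfig(AOBaseWorkflowConfig):
        # hypothetical config, for illustration only
        group_size: int = 128


    @register_quantize_module_handler(MyWeightOnlyConfig)
    def _my_weight_only_transform(
        module: torch.nn.Module, config: MyWeightOnlyConfig
    ) -> torch.nn.Module:
        # a real handler would swap module.weight for a quantized tensor here
        return module


    # `quantize_` looks the handler up by the config's type and calls it per module
    config = MyWeightOnlyConfig(group_size=64)
    handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
    linear = handler(torch.nn.Linear(16, 16), config)
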
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index cd3813291f..6356ee1600 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -5,10 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.core.config import AOBaseWorkflowConfig
+from torchao.quantization._transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.granularity import (
     Granularity,
     PerAxis,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
             super().__setattr__(name, value)
 
 
-def intx_quantization_aware_training(
-    activation_config: Optional[FakeQuantizeConfig] = None,
-    weight_config: Optional[FakeQuantizeConfig] = None,
-) -> Callable:
+@dataclass
+class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
+    activation_config: Optional[FakeQuantizeConfig] = None
+    weight_config: Optional[FakeQuantizeConfig] = None
+
+
+# for BC
+intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
+
+
+@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
+def _intx_quantization_aware_training_transform(
+    module: torch.nn.Module,
+    config: IntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
     """
-    Return a function that applies fake quantization to a `torch.nn.Module`.
+    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
+    can break at any time.
+
+    Apply fake quantization to a `torch.nn.Module`,
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
 
     Example usage::
@@ -267,37 +285,32 @@ def intx_quantization_aware_training(
     `torch.nn.Embedding` with an activation config, then we will raise
     ValueError as these are not supported.
     """
-
-    def _insert_fake_quantize(mod: torch.nn.Module):
-        """
-        Swap the given module with its corresponding fake quantized version.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
-
-        if isinstance(mod, torch.nn.Linear):
-            return FakeQuantizedLinear.from_linear(
-                mod,
-                activation_config,
-                weight_config,
-            )
-        elif isinstance(mod, torch.nn.Embedding):
-            if activation_config is not None:
-                raise ValueError(
-                    "Activation fake quantization is not supported for embedding"
-                )
-            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-        else:
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    mod = module
+    activation_config = config.activation_config
+    weight_config = config.weight_config
+
+    if isinstance(mod, torch.nn.Linear):
+        return FakeQuantizedLinear.from_linear(
+            mod,
+            activation_config,
+            weight_config,
+        )
+    elif isinstance(mod, torch.nn.Embedding):
+        if activation_config is not None:
             raise ValueError(
-                "Module of type '%s' does not have QAT support" % type(mod)
+                "Activation fake quantization is not supported for embedding"
             )
+        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
+    else:
+        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
 
-    return _insert_fake_quantize
 
-
-def from_intx_quantization_aware_training() -> Callable:
+class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
     """
-    Return a function that converts a model with fake quantized modules,
+    Object that knows how to convert a model with fake quantized modules,
     such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
     and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
     back to model with the original, corresponding modules without
@@ -313,22 +326,31 @@ def from_intx_quantization_aware_training() -> Callable:
         )
     """
 
-    def _remove_fake_quantize(mod: torch.nn.Module):
-        """
-        If the given module is a fake quantized module, return the original
-        corresponding version of the module without fake quantization.
-        """
-        from .embedding import FakeQuantizedEmbedding
-        from .linear import FakeQuantizedLinear
+    pass
+
+
+# for BC
+from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
 
-        if isinstance(mod, FakeQuantizedLinear):
-            return mod.to_linear()
-        elif isinstance(mod, FakeQuantizedEmbedding):
-            return mod.to_embedding()
-        else:
-            return mod
 
-    return _remove_fake_quantize
+@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
+def _from_intx_quantization_aware_training_transform(
+    mod: torch.nn.Module,
+    config: FromIntXQuantizationAwareTrainingConfig,
+) -> torch.nn.Module:
+    """
+    If the given module is a fake quantized module, return the original
+    corresponding version of the module without fake quantization.
+    """
+    from .embedding import FakeQuantizedEmbedding
+    from .linear import FakeQuantizedLinear
+
+    if isinstance(mod, FakeQuantizedLinear):
+        return mod.to_linear()
+    elif isinstance(mod, FakeQuantizedEmbedding):
+        return mod.to_embedding()
+    else:
+        return mod
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
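
As a usage sketch of the two QAT configs above (hedged: the public re-exports only land in the last patch of this stack, so the import paths assume that final state), prepare and convert both go through `quantize_`:

    import torch

    from torchao.quantization import quantize_
    from torchao.quantization.qat import (
        FakeQuantizeConfig,
        FromIntXQuantizationAwareTrainingConfig,
        IntXQuantizationAwareTrainingConfig,
    )

    model = torch.nn.Sequential(torch.nn.Linear(32, 32))

    # prepare: swap nn.Linear -> FakeQuantizedLinear
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        model,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # ... fine-tune with fake quantization ...

    # convert: swap FakeQuantizedLinear back to plain nn.Linear
    quantize_(model, FromIntXQuantizationAwareTrainingConfig())
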
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b2eff196fd..450563be36 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,13 +18,15 @@
 import logging
 import types
 import warnings
-from typing import Callable, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.utils.parametrize as parametrize
 
 import torchao
+from torchao.core.config import AOBaseWorkflowConfig
 from torchao.dtypes import (
     AffineQuantizedTensor,
     CutlassInt4PackedLayout,
@@ -43,6 +45,10 @@
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
+from torchao.quantization._transform_module import (
+    _QUANTIZE_CONFIG_HANDLER,
+    register_quantize_module_handler,
+)
 from torchao.quantization.linear_activation_weight_observed_tensor import (
     LinearActivationWeightObservedTensor,
 )
@@ -117,7 +123,6 @@
     "Int8DynActInt4WeightGPTQQuantizer",
 ]
 
-# update according to the support matrix
 LAYOUT_TO_ZERO_POINT_DOMAIN = {
     TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
     MarlinSparseLayout: [ZeroPointDomain.INT],
@@ -228,6 +233,7 @@ def _replace_with_custom_fn_if_matches_filter(
     filter_fn,
     cur_fqn="",
     device=None,
+    extra_args: Optional[Tuple[Any, ...]] = (),
 ) -> None:
     """
     Recursively replaces each child module in `model` with the result of `replacement_fn(child)`
@@ -239,6 +245,7 @@ def _replace_with_custom_fn_if_matches_filter(
         filter_fn (Callable[[torch.nn.Module], bool]): The filter function to determine which modules to replace.
         cur_fqn (str, optional): The current fully qualified name of the module being processed. Defaults to "".
         device (device, optional): Device to move the model to before applying `filter_fn`. Defaults to None.
+        extra_args (Tuple[Any, ...], optional): optional extra args to pass to `replacement_fn`.
 
     Returns:
         None
@@ -252,12 +259,17 @@ def _replace_with_custom_fn_if_matches_filter(
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
-        model = replacement_fn(model)
+        model = replacement_fn(model, *extra_args)
         return model
     else:
         for name, child in model.named_children():
             new_child = _replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, f"{cur_fqn}{name}.", device
+                child,
+                replacement_fn,
+                filter_fn,
+                f"{cur_fqn}{name}.",
+                device,
+                extra_args,
             )
             if new_child is not child:
                 setattr(model, name, new_child)
@@ -468,7 +480,10 @@ def insert_subclass(lin):
 
 def quantize_(
     model: torch.nn.Module,
-    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
+    apply_tensor_subclass: Union[
+        Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig
+    ],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
     set_inductor_config: bool = True,
     device: Optional[torch.types.Device] = None,
@@ -530,12 +545,33 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        apply_tensor_subclass,
-        _is_linear if filter_fn is None else filter_fn,
-        device=device,
-    )
+    if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig):
+        # new behavior
+
+        # make the variable name make sense
+        config = apply_tensor_subclass
+        handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
+
+        # for each linear in the model, apply the transform if filtering passes
+        # key difference from the old behavior is that `config` is easily
+        # inspectable
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            handler,
+            _is_linear if filter_fn is None else filter_fn,
+            device=device,
+            extra_args=(config,),
+        )
+
+    else:
+        # old behavior, for now keep for BC purposes
+        # TODO(after discussion): flesh the BC story out more
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            apply_tensor_subclass,
+            _is_linear if filter_fn is None else filter_fn,
+            device=device,
+        )
 
 
 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
@@ -684,14 +720,10 @@ def gemlite_uintx_weight_only(
     return _get_linear_subclass_inserter(apply_fn)
 
 
-def int4_weight_only(
-    group_size=128,
-    layout=TensorCoreTiledLayout(inner_k_tiles=8),
-    use_hqq=False,
-    zero_point_domain=None,
-):
+@dataclass
+class Int4WeightOnlyConfig(AOBaseWorkflowConfig):
     """
-    Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
+    Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
 
     Note:
@@ -711,59 +743,84 @@ def int4_weight_only(
         `zero_point_domain`: data type of zeros points, choices are [None(then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
     """
 
-    def apply_int4_weight_only_quant(weight):
-        if weight.shape[-1] % group_size != 0:
-            logger.info(
-                f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
-            )
-            return weight
+    group_size: int = 128
+    layout: Optional[TensorCoreTiledLayout] = TensorCoreTiledLayout(inner_k_tiles=8)
+    use_hqq: bool = False
+    zero_point_domain: Optional[ZeroPointDomain] = None
 
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
-        zero_point_dtype = torch.bfloat16
-
-        nonlocal zero_point_domain
-        assert (
-            type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
-        ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
-        if zero_point_domain is None:
-            # the first value is the default one
-            zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
-        else:
-            assert (
-                zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
-            ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
-
-        # Sparse Marlin only supports symmetric quantization.
-        # NOTE: If we start having lots of layouts that require different configurations,
-        # we should consider moving this logic somewhere else.
-        if isinstance(layout, MarlinSparseLayout):
-            mapping_type = MappingType.SYMMETRIC
-            assert (
-                group_size == 128 or group_size == weight.shape[-1]
-            ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
 
-        return to_affine_quantized_intx(
-            weight,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=layout,
-            use_hqq=use_hqq,
+# for BC
+# TODO maybe change other callsites
+int4_weight_only = Int4WeightOnlyConfig
+
+
+@register_quantize_module_handler(Int4WeightOnlyConfig)
+def _int4_weight_only_transform(
+    module: torch.nn.Module, config: Int4WeightOnlyConfig
+) -> torch.nn.Module:
+    # TODO(future PR): perhaps move this logic to a different file, to keep the API
+    # file clean of implementation details
+
+    # for now, make these local variables to allow the rest of the function
+    # to be a direct copy-paste
+    weight = module.weight
+    group_size = config.group_size
+    layout = config.layout
+    use_hqq = config.use_hqq
+    zero_point_domain = config.zero_point_domain
+
+    if weight.shape[-1] % group_size != 0:
+        logger.info(
+            f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
+        return weight
 
-    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    target_dtype = torch.int32
+    quant_min = 0
+    quant_max = 15
+    eps = 1e-6
+    preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
+    zero_point_dtype = torch.bfloat16
+
+    # nonlocal zero_point_domain
+    assert (
+        type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
+    ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
+    if zero_point_domain is None:
+        # the first value is the default one
+        zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
+    else:
+        assert (
+            zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
+        ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"
+
+    # Sparse Marlin only supports symmetric quantization.
+    # NOTE: If we start having lots of layouts that require different configurations,
+    # we should consider moving this logic somewhere else.
+    if isinstance(layout, MarlinSparseLayout):
+        mapping_type = MappingType.SYMMETRIC
+        assert (
+            group_size == 128 or group_size == weight.shape[-1]
+        ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
+
+    new_weight = to_affine_quantized_intx(
+        weight,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=layout,
+        use_hqq=use_hqq,
+    )
+    module.weight = torch.nn.Parameter(new_weight)
+    return module
 
 
 def int8_weight_only(group_size=None):
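
A short end-to-end sketch of what this patch changes for int4 (assuming a CUDA device and bfloat16 weights, as the tinygemm path expects): existing call sites keep working because `int4_weight_only` is now an alias for the config class, and `quantize_` dispatches on the config type via the registry:

    import torch

    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import Int4WeightOnlyConfig, int4_weight_only

    m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    # old spelling: still works, now constructs an Int4WeightOnlyConfig
    quantize_(m, int4_weight_only(group_size=32))

    m2 = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    # new spelling: the config is plain data, easy to inspect and serialize
    quantize_(m2, Int4WeightOnlyConfig(group_size=32))
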

From 5b9d876d7ea41db7964278c6b59b27e6b79645fb Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 10:08:28 -0800
Subject: [PATCH 02/10] Update

[ghstack-poisoned]
---
 test/dtypes/test_affine_quantized.py |  7 +++-
 test/quantization/test_quant_api.py  |  7 ++--
 torchao/quantization/quant_api.py    | 58 ++++++++--------------------
 3 files changed, 25 insertions(+), 47 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index f08ba7aa72..1b4bf58cf9 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -8,6 +8,7 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseWorkflowConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
@@ -15,6 +16,7 @@
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
@@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseWorkflowConfig):
+                quantize_(linear, apply_quant) 
+            else:
+                ql = apply_quant(linear)
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index ca2cbf08ec..80536bfac9 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -762,17 +762,16 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int4_weight_only_numerics(self):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
-        # TODO(before land) skip on cpu-only
-        # TODO(before land) support other inference techniques?
-
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
-        # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
         m_int4_wo = copy.deepcopy(m_ref)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 450563be36..e36bc7d8e3 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -262,7 +262,8 @@ def _replace_with_custom_fn_if_matches_filter(
         model = replacement_fn(model, *extra_args)
         return model
     else:
-        for name, child in model.named_children():
+        named_children_list = list(model.named_children())
+        for name, child in named_children_list:
             new_child = _replace_with_custom_fn_if_matches_filter(
                 child,
                 replacement_fn,
@@ -480,20 +481,19 @@ def insert_subclass(lin):
 
 def quantize_(
     model: torch.nn.Module,
-    # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
-    apply_tensor_subclass: Union[
-        Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig
+    config: Union[
+        AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]
     ],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
     set_inductor_config: bool = True,
     device: Optional[torch.types.Device] = None,
 ):
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
+    """Convert the weight of linear modules in the model with `config`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
+        config (Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release.
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on
         the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
         device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
@@ -505,7 +505,7 @@ def quantize_(
         import torch.nn as nn
         from torchao import quantize_
 
-        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # quantize with some predefined `config` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
         # also customizable with arguments
         # currently options are
@@ -518,43 +518,13 @@ def quantize_(
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, int4_weight_only(group_size=32))
 
-        # 2. write your own new apply_tensor_subclass
-        # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
-        # on weight
-
-        from torchao.dtypes import to_affine_quantized_intx
-
-        # weight only uint4 asymmetric groupwise quantization
-        groupsize = 32
-        apply_weight_quant = lambda x: to_affine_quantized_intx(
-          x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
-          zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
-
-        def apply_weight_quant_to_linear(linear):
-            linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
-            return linear
-
-        # apply to modules under block0 submodule
-        def filter_fn(module: nn.Module, fqn: str) -> bool:
-            return isinstance(module, nn.Linear)
-
-        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        quantize_(m, apply_weight_quant_to_linear, filter_fn)
-
     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
-    if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig):
-        # new behavior
-
-        # make the variable name make sense
-        config = apply_tensor_subclass
+    if isinstance(config, AOBaseWorkflowConfig):
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
-
         # for each linear in the model, apply the transform if filtering passes
-        # key difference from the old behavior is that `config` is easily
-        # inspectable
         _replace_with_custom_fn_if_matches_filter(
             model,
             handler,
@@ -564,8 +534,12 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         )
 
     else:
-        # old behavior, for now keep for BC purposes
-        # TODO(after discussion): flesh the BC story out more
+        # old behavior, keep to avoid breaking BC
+        warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""")
+
+        # make the variable name make sense
+        apply_tensor_subclass = config
+
         _replace_with_custom_fn_if_matches_filter(
             model,
             apply_tensor_subclass,
@@ -773,7 +747,7 @@ def _int4_weight_only_transform(
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
-        return weight
+        return module
 
     mapping_type = MappingType.ASYMMETRIC
     block_size = (1, group_size)
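
For the BC branch kept above: a raw callable is still accepted for now, but it hits the new warning. A sketch with a hypothetical legacy-style transform (`insert_my_subclass` is illustrative, not a torchao API):

    import torch

    from torchao.quantization import quantize_


    def insert_my_subclass(lin: torch.nn.Linear) -> torch.nn.Linear:
        # hypothetical legacy-style module transform
        lin.weight = torch.nn.Parameter(
            lin.weight.to(torch.bfloat16), requires_grad=False
        )
        return lin


    m = torch.nn.Sequential(torch.nn.Linear(32, 32))
    quantize_(m, insert_my_subclass)  # still works, but emits the deprecation warning
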

From 1cea42fbd49f534c697471f9c35c424768607985 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 10:39:15 -0800
Subject: [PATCH 03/10] Update

[ghstack-poisoned]
---
 test/dtypes/test_affine_quantized.py | 2 +-
 torchao/quantization/quant_api.py    | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 1b4bf58cf9..9ef26026e2 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -189,7 +189,7 @@ def test_flatten_unflatten(self, device, dtype):
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
             if isinstance(apply_quant, AOBaseWorkflowConfig):
-                quantize_(linear, apply_quant) 
+                quantize_(linear, apply_quant)
             else:
                 ql = apply_quant(linear)
             lp_tensor = ql.weight
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e36bc7d8e3..efda1dbb23 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -481,9 +481,7 @@ def insert_subclass(lin):
 
 def quantize_(
     model: torch.nn.Module,
-    config: Union[
-        AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]
-    ],
+    config: Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
     set_inductor_config: bool = True,
     device: Optional[torch.types.Device] = None,
@@ -535,7 +533,9 @@ def quantize_(
 
     else:
         # old behavior, keep to avoid breaking BC
-        warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""")
+        warnings.warn(
+            """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead."""
+        )
 
         # make the variable name make sense
         apply_tensor_subclass = config

From 138883b4f40073517c1a5a71dd87c00d33c87d43 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 12:44:06 -0800
Subject: [PATCH 04/10] Update

[ghstack-poisoned]
---
 test/dtypes/test_affine_quantized.py | 19 +++++++++++++++----
 test/hqq/test_hqq_affine.py          |  7 ++++---
 torchao/quantization/quant_api.py    |  1 +
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 9ef26026e2..671c676e76 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -60,7 +60,8 @@ def get_quantization_functions(
                     )
                 )
 
-    if do_sparse:
+    # TODO(before land): revert this back, added due to lack of cuSparseLt in my env
+    if do_sparse and False:
         base_functions.append(
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )
@@ -78,7 +79,8 @@ def test_tensor_core_layout_transpose(self):
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(linear)
+        quantize_(linear, apply_int4_weight_only_quant)
+        ql = linear
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
@@ -97,7 +99,11 @@ def test_tensor_core_layout_transpose(self):
     )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseWorkflowConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
@@ -173,8 +179,13 @@ def apply_uint6_weight_only_quant(linear):
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
+        print(apply_quant)
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseWorkflowConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)
 
 
diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 381886d594..096c9d26ba 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -6,6 +6,7 @@
     MappingType,
     ZeroPointDomain,
     int4_weight_only,
+    quantize_,
     uintx_weight_only,
 )
 from torchao.utils import (
@@ -51,9 +52,9 @@ def _eval_hqq(dtype):
     )
     dummy_linear.weight.data = W
     if dtype == torch.uint4:
-        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
-            dummy_linear
-        ).weight
+        config = int4_weight_only(group_size=max(block_size), use_hqq=True)
+        quantize_(dummy_linear, config)
+        q_tensor_hqq = dummy_linear.weight
     else:
         q_tensor_hqq = uintx_weight_only(
             dtype, group_size=max(block_size), use_hqq=True
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index efda1dbb23..1c7284a01d 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -794,6 +794,7 @@ def _int4_weight_only_transform(
         use_hqq=use_hqq,
     )
     module.weight = torch.nn.Parameter(new_weight)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
 
 

From ba045ea89316a7a14b92d4849f44e9ff1ad276f5 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 12:56:28 -0800
Subject: [PATCH 05/10] Update

[ghstack-poisoned]
---
 test/dtypes/test_affine_quantized.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 671c676e76..2cb87ab133 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -60,8 +60,7 @@ def get_quantization_functions(
                     )
                 )
 
-    # TODO(before land): revert this back, added due to lack of cuSparseLt in my env
-    if do_sparse and False:
+    if do_sparse:
         base_functions.append(
             int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
         )

From 94d942606bcea5bad5c36b819d779deaa7c1572b Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 22 Jan 2025 15:08:47 -0800
Subject: [PATCH 06/10] Update

[ghstack-poisoned]
---
 torchao/quantization/quant_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 1c7284a01d..3401a42ab7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -793,7 +793,7 @@ def _int4_weight_only_transform(
         _layout=layout,
         use_hqq=use_hqq,
     )
-    module.weight = torch.nn.Parameter(new_weight)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module
 

From 26850dae92bdcf6535fcf30ca4fc21f4074bde44 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Wed, 5 Feb 2025 13:34:44 -0800
Subject: [PATCH 07/10] Update

[ghstack-poisoned]
---
 torchao/quantization/quant_api.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index c598393b50..a6e8ee8e0b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -779,6 +779,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     use_hqq: bool = False
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
 
+
 # for BC
 # TODO maybe change other callsites
 int4_weight_only = Int4WeightOnlyConfig
@@ -812,7 +813,9 @@ def _int4_weight_only_transform(
     quant_max = 15
     eps = 1e-6
     preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
-    zero_point_dtype = torch.bfloat16
+    zero_point_dtype = (
+        weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16
+    )
 
     # nonlocal zero_point_domain
     assert (

From d42a59070b95f85983226d280f2de60a7dbf8735 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Mon, 10 Feb 2025 15:33:48 -0800
Subject: [PATCH 08/10] Update

[ghstack-poisoned]
---
 test/dtypes/test_affine_quantized.py |  14 +-
 test/quantization/test_quant_api.py  |  25 ++-
 torchao/quantization/quant_api.py    | 246 ++++++++++++++++-----------
 3 files changed, 180 insertions(+), 105 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 53ca470b04..d26f1d8e04 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index acd9b50c5a..b9220c2815 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -784,9 +787,21 @@ def test_int4wo_cpu(self, dtype, x_dim):
             assert "_weight_int4pack_mm_for_cpu" in code[0]
             assert "aten.mm.default" not in code[0]
 
+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(
+                scale=torch.tensor([1.0], device="cuda")
+            ),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
@@ -796,16 +811,16 @@ def test_int4_weight_only_numerics(self):
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)
 
         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)
 
         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)
 
-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
 
 
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a6e8ee8e0b..01e3a7c029 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
-def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
+@dataclass
+class Float8WeightOnlyConfig(AOBaseConfig):
     """
-    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
 
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
-
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
-    def apply_float8wo_quant(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+
+# for BC
+float8_weight_only = Float8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Float8WeightOnlyConfig)
+def _float8_weight_only_transform(
+    module: torch.nn.Module, config: Float8WeightOnlyConfig
+) -> torch.nn.Module:
+    from torchao.dtypes import to_affine_quantized_floatx
 
-    return _get_linear_subclass_inserter(apply_float8wo_quant)
+    weight = module.weight
+    block_size = (1, weight.shape[1])
+    new_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=config.weight_dtype,
+        scale_dtype=None,
+        _layout=Float8Layout(mm_config=None),
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 _fp8_granularities = Union[PerTensor, PerRow]
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     return is_compatible
 
 
-def float8_dynamic_activation_float8_weight(
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
 
     Args:
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,56 +1199,75 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)
 
-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
 
-    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        if isinstance(weight_granularity, PerRow):
-            assert (
-                weight.dtype == torch.bfloat16
-            ), "PerRow quantization only works for bfloat16 precision input weight"
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
 
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
 
-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
+# for bc
+float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
 
-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-        return quantized_weight
 
-    return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
+def _float8_dynamic_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
+):
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+    weight = module.weight
 
+    activation_granularity, weight_granularity = _normalize_granularity(granularity)
 
-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    if isinstance(weight_granularity, PerRow):
+        assert (
+            weight.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input weight"
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
+
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to both activations and weights of linear layers.
 
     Args:
         scale (torch.Tensor): The scale tensor for activation quantization.
@@ -1249,47 +1275,73 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)
 
+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for bc
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
     assert isinstance(
         activation_granularity, PerTensor
     ), "Static quantization only supports PerTensor granularity"
 
-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
 
-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
 
-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
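
A usage sketch for the float8 configs converted in this patch (hedged: the dynamic and static variants assert on CUDA SM 8.9+ or MI300, per the checks above):

    import torch

    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import (
        float8_dynamic_activation_float8_weight,  # BC alias for Float8DynamicActivationFloat8WeightConfig
        float8_weight_only,  # BC alias for Float8WeightOnlyConfig
    )

    m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    quantize_(m, float8_weight_only())

    m2 = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    quantize_(m2, float8_dynamic_activation_float8_weight())
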

From 5702ea030a5163cfe53d2b1ff8cf0610b5cea5fc Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Mon, 10 Feb 2025 16:51:31 -0800
Subject: [PATCH 09/10] Update

[ghstack-poisoned]
---
 test/quantization/test_quant_api.py | 13 +++++++++++++
 torchao/quantization/quant_api.py   | 14 ++++++++------
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index b9220c2815..61ea2c5558 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -49,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )
 
@@ -806,6 +807,18 @@ def test_workflow_e2e_numerics(self, config):
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            self.skipTest("requires CUDA capability 8.9 or greater")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 01e3a7c029..12ac02096e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1208,9 +1208,6 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = None
 
     def __post_init__(self):
-        assert (
-            is_sm_at_least_89() or is_MI300()
-        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
 
@@ -1223,6 +1220,10 @@ def __post_init__(self):
 def _float8_dynamic_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
 ):
+    assert (
+        is_sm_at_least_89() or is_MI300()
+    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
     granularity = config.granularity
@@ -1285,9 +1286,6 @@ class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     mm_config: Optional[Float8MMConfig] = None
 
     def __post_init__(self):
-        assert (
-            is_sm_at_least_89() or is_MI300()
-        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
         if self.mm_config is None:
             self.mm_config = Float8MMConfig(use_fast_accum=True)
 
@@ -1300,6 +1298,10 @@ def __post_init__(self):
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
 ):
+    assert (
+        is_sm_at_least_89() or is_MI300()
+    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype

From 0542402b263299ac8cc643899f85913a9c037f28 Mon Sep 17 00:00:00 2001
From: vasiliy <vasiliy@fb.com>
Date: Mon, 10 Feb 2025 19:10:03 -0800
Subject: [PATCH 10/10] Update

[ghstack-poisoned]
---
 torchao/quantization/__init__.py     | 2 ++
 torchao/quantization/qat/__init__.py | 4 ++++
 2 files changed, 6 insertions(+)

diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index b68ab8a179..71e8de337a 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -46,6 +46,7 @@
     AffineQuantizedObserverBase,
 )
 from .quant_api import (
+    Int4WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
@@ -119,6 +120,7 @@
     "fpx_weight_only",
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
+    "Int4WeightOnlyConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 15008e03ea..5dc3d8e008 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -1,6 +1,8 @@
 from .api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    FromIntXQuantizationAwareTrainingConfig,
+    IntXQuantizationAwareTrainingConfig,
     from_intx_quantization_aware_training,
     intx_quantization_aware_training,
 )
@@ -20,4 +22,6 @@
     "Int8DynActInt4WeightQATQuantizer",
     "intx_quantization_aware_training",
     "from_intx_quantization_aware_training",
+    "FromIntXQuantizationAwareTrainingConfig",
+    "IntXQuantizationAwareTrainingConfig",
 ]
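
With these exports in place, the new names resolve from the public packages; a quick hedged import check:

    from torchao.quantization import Int4WeightOnlyConfig
    from torchao.quantization.qat import (
        FromIntXQuantizationAwareTrainingConfig,
        IntXQuantizationAwareTrainingConfig,
    )

    print(Int4WeightOnlyConfig(group_size=32))
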