diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 53ca470b04..d26f1d8e04 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")
 
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
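The `_apply` helper above bridges the two shapes that `get_quantization_functions` can yield during the config migration: a new-style `AOBaseConfig` instance, which goes through `quantize_`, and a legacy subclass-inserter callable, which transforms the module directly. A minimal standalone sketch of the same dispatch, assuming `AOBaseConfig` is importable from `torchao.core.config` and a CUDA device is available:

```python
import torch

from torchao.core.config import AOBaseConfig  # assumed import path for the base config class
from torchao.quantization import int4_weight_only, quantize_


def apply_config_or_inserter(module, config_or_subclass_inserter):
    """Hypothetical helper mirroring the `_apply` shim in the test above."""
    if isinstance(config_or_subclass_inserter, AOBaseConfig):
        # New-style path: quantize_ consumes the config and mutates the module in place.
        quantize_(module, config_or_subclass_inserter)
    else:
        # Legacy path: the object is a callable that returns the transformed module.
        module = config_or_subclass_inserter(module)
    return module


linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# int4_weight_only() is now a config instance, so the first branch is taken.
linear = apply_config_or_inserter(linear, int4_weight_only())
```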
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index acd9b50c5a..e0f6cb1ace 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -46,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )
 
@@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim):
             assert "_weight_int4pack_mm_for_cpu" in code[0]
             assert "aten.mm.default" not in code[0]
 
+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
-        Simple test of e2e int4_weight_only workflow, comparing numerics
-        to a bfloat16 baseline.
+        Simple test of e2e quantization workflows, comparing numerics
+        to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            self.skipTest("requires CUDA capability 8.9 or greater")
+
+        # the scale tensor is created at parametrization time, before the CUDA
+        # availability gating above runs, so move it to CUDA here
+        if isinstance(config, float8_static_activation_float8_weight):
+            config.scale = config.scale.to("cuda")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)
 
         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)
 
         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)
 
-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
 
 
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 71e8de337a..ca9a4141fc 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -46,6 +46,9 @@
     AffineQuantizedObserverBase,
 )
 from .quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     Int4WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
@@ -121,6 +124,9 @@
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4WeightOnlyConfig",
+    "Float8WeightOnlyConfig",
+    "Float8DynamicActivationFloat8WeightConfig",
+    "Float8StaticActivationFloat8WeightConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9f6599c177..6e5e043fb0 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
 
-def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
+@dataclass
+class Float8WeightOnlyConfig(AOBaseConfig):
     """
-    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.
 
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
 
     Note:
         The actual matmul will be computed in original precision of the weight tensor.
-
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
-    def apply_float8wo_quant(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
 
-    return _get_linear_subclass_inserter(apply_float8wo_quant)
+
+# for BC
+float8_weight_only = Float8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Float8WeightOnlyConfig)
+def _float8_weight_only_transform(
+    module: torch.nn.Module, config: Float8WeightOnlyConfig
+) -> torch.nn.Module:
+    from torchao.dtypes import to_affine_quantized_floatx
+
+    weight = module.weight
+    block_size = (1, weight.shape[1])
+    new_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=config.weight_dtype,
+        scale_dtype=None,
+        _layout=Float8Layout(mm_config=None),
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 _fp8_granularities = Union[PerTensor, PerRow]
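After `_float8_weight_only_transform` runs, the module keeps its `nn.Linear` type: only its `weight` is swapped for a frozen parameter backed by a float8-layout affine-quantized tensor, and `extra_repr` is patched so the quantization is visible when printing the module. A small sketch of what a caller should observe, assuming an fp8-capable CUDA device:

```python
import torch

from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = torch.nn.Linear(128, 128).cuda().bfloat16()
quantize_(m, Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn))

# The module type is unchanged; only the weight parameter was replaced and frozen.
assert isinstance(m, torch.nn.Linear)
assert m.weight.requires_grad is False
print(m)  # the patched extra_repr includes the quantized weight representation
```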
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     return is_compatible
 
 
-def float8_dynamic_activation_float8_weight(
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.
 
     Args:
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,56 +1199,76 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
 
     """
+
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for BC
+float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
+def _float8_dynamic_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
+):
     assert (
         is_sm_at_least_89() or is_MI300()
     ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)
 
-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+    weight = module.weight
 
-    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        if isinstance(weight_granularity, PerRow):
-            assert (
-                weight.dtype == torch.bfloat16
-            ), "PerRow quantization only works for bfloat16 precision input weight"
+    activation_granularity, weight_granularity = _normalize_granularity(granularity)
 
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    if isinstance(weight_granularity, PerRow):
+        assert (
+            weight.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input weight"
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
 
-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
 
-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-        return quantized_weight
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
 
-    return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to both activations and weights of linear layers.
 
     Args:
         scale (torch.Tensor): The scale tensor for activation quantization.
@@ -1249,47 +1276,74 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+
+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for BC
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
     assert (
         is_sm_at_least_89() or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)
 
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
     assert isinstance(
         activation_granularity, PerTensor
     ), "Static quantization only supports PerTensor granularity"
 
-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
 
-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
 
-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
 
 
 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
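Because each legacy function name is now a plain alias of its config class (for example `float8_weight_only = Float8WeightOnlyConfig`), existing call sites keep working unchanged while new code can name the config explicitly. A minimal equivalence sketch, assuming an fp8-capable CUDA device:

```python
import torch

from torchao.quantization import Float8WeightOnlyConfig, float8_weight_only, quantize_

# The legacy name and the config class are the same object.
assert float8_weight_only is Float8WeightOnlyConfig
assert isinstance(float8_weight_only(), Float8WeightOnlyConfig)

m_legacy = torch.nn.Linear(64, 64).cuda().bfloat16()
quantize_(m_legacy, float8_weight_only())  # legacy spelling

m_new = torch.nn.Linear(64, 64).cuda().bfloat16()
quantize_(m_new, Float8WeightOnlyConfig())  # explicit config spelling
```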