diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 53ca470b04..d26f1d8e04 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -123,16 +123,24 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
+        def _apply(module, config_or_subclass_inserter):
+            if isinstance(config_or_subclass_inserter, AOBaseConfig):
+                quantize_(module, config_or_subclass_inserter)
+            else:
+                # TODO(#1690): delete this once config migration is done
+                module = config_or_subclass_inserter(module)
+            return module
+
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to("cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.to(device="cuda")

         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(linear)
+        ql = _apply(linear, apply_quant)
         ql.cuda()

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index acd9b50c5a..e0f6cb1ace 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -30,6 +30,9 @@
     Quantizer,
     TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -46,6 +49,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
     unwrap_tensor_subclass,
 )
@@ -784,28 +788,55 @@ def test_int4wo_cpu(self, dtype, x_dim):
         assert "_weight_int4pack_mm_for_cpu" in code[0]
         assert "aten.mm.default" not in code[0]

+    # TODO(#1690): move to new config names
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_int4_weight_only_numerics(self):
+    @common_utils.parametrize(
+        "config",
+        [
+            int4_weight_only(),
+            float8_weight_only(),
+            float8_dynamic_activation_float8_weight(),
+            float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+        ],
+    )
+    def test_workflow_e2e_numerics(self, config):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
+        if (
+            isinstance(
+                config,
+                (
+                    float8_dynamic_activation_float8_weight,
+                    float8_static_activation_float8_weight,
+                ),
+            )
+            and not is_sm_at_least_89()
+        ):
+            return unittest.skip("requires CUDA capability 8.9 or greater")
+
+        # scale has to be moved to cuda here because the parametrization init
+        # code happens before gating for cuda availability
+        if isinstance(config, float8_static_activation_float8_weight):
+            config.scale = config.scale.to("cuda")
+
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
-        m_int4_wo = copy.deepcopy(m_ref)
+        m_q = copy.deepcopy(m_ref)

         # quantize
-        quantize_(m_int4_wo, int4_weight_only())
+        quantize_(m_q, config)

         with torch.no_grad():
             y_ref = m_ref(x)
-            y_int4_wo = m_int4_wo(x)
+            y_q = m_q(x)

-        sqnr = compute_error(y_ref, y_int4_wo)
+        sqnr = compute_error(y_ref, y_q)
         assert sqnr >= 20, f"SQNR {sqnr} is too low"
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 71e8de337a..ca9a4141fc 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -46,6 +46,9 @@
     AffineQuantizedObserverBase,
 )
 from .quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     Int4WeightOnlyConfig,
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
@@ -121,6 +124,9 @@
     "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4WeightOnlyConfig",
+    "Float8WeightOnlyConfig",
+    "Float8DynamicActivationFloat8WeightConfig",
+    "Float8StaticActivationFloat8WeightConfig",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
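
A minimal usage sketch for the re-exports above, assuming a torchao build that includes this patch: the lowercase names are kept as aliases of the new config classes, so existing call sites keep working unchanged.

    from torchao.quantization import (
        Float8WeightOnlyConfig,
        float8_weight_only,  # BC alias; bound to the same class by this patch
    )

    # the alias and the class are literally the same object
    assert float8_weight_only is Float8WeightOnlyConfig
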
- """ - from torchao.dtypes import to_affine_quantized_floatx - def apply_float8wo_quant(weight): - block_size = (1, weight.shape[1]) - return to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=None, - _layout=Float8Layout(mm_config=None), - ) + weight_dtype: torch.dtype = torch.float8_e4m3fn - return _get_linear_subclass_inserter(apply_float8wo_quant) + +# for BC +float8_weight_only = Float8WeightOnlyConfig + + +@register_quantize_module_handler(Float8WeightOnlyConfig) +def _float8_weight_only_transform( + module: torch.nn.Module, config: Float8WeightOnlyConfig +) -> torch.nn.Module: + from torchao.dtypes import to_affine_quantized_floatx + + weight = module.weight + block_size = (1, weight.shape[1]) + new_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=config.weight_dtype, + scale_dtype=None, + _layout=Float8Layout(mm_config=None), + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module _fp8_granularities = Union[PerTensor, PerRow] @@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool: return is_compatible -def float8_dynamic_activation_float8_weight( - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: Optional[Float8MMConfig] = None, -): +@dataclass +class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): """ - Applies float8 dynamic symmetric quantization to both activations and weights of linear layers. + Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers. Args: activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn. @@ -1192,56 +1199,76 @@ def float8_dynamic_activation_float8_weight( mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. 
""" + + activation_dtype: torch.dtype = torch.float8_e4m3fn + weight_dtype: torch.dtype = torch.float8_e4m3fn + granularity: Optional[ + Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] + ] = None + mm_config: Optional[Float8MMConfig] = None + + def __post_init__(self): + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + + +# for bc +float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig + + +@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) +def _float8_dynamic_activation_float8_weight_transform( + module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig +): assert ( is_sm_at_least_89() or is_MI300() ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - if mm_config is None: - mm_config = Float8MMConfig(use_fast_accum=True) - activation_granularity, weight_granularity = _normalize_granularity(granularity) + activation_dtype = config.activation_dtype + weight_dtype = config.weight_dtype + granularity = config.granularity + mm_config = config.mm_config + weight = module.weight - def apply_float8_dynamic_activation_quant(weight: torch.Tensor): - if not _fp8_mm_compat(weight): - return weight - if isinstance(weight_granularity, PerRow): - assert ( - weight.dtype == torch.bfloat16 - ), "PerRow quantization only works for bfloat16 precision input weight" + activation_granularity, weight_granularity = _normalize_granularity(granularity) - block_size = get_block_size(weight.shape, weight_granularity) - quantized_weight = to_affine_quantized_floatx( - input_float=weight, - block_size=block_size, - target_dtype=weight_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), - ) + if not _fp8_mm_compat(weight): + # TODO(future PR): this should really throw an exception instead of silently + # not doing what the user asked + return module + if isinstance(weight_granularity, PerRow): + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" + + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = to_affine_quantized_floatx( + input_float=weight, + block_size=block_size, + target_dtype=weight_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=mm_config), + ) - input_quant_func = _input_activation_quant_func_fp8 - input_quant_kwargs = { - "activation_granularity": activation_granularity, - "activation_dtype": activation_dtype, - } + input_quant_func = _input_activation_quant_func_fp8 + input_quant_kwargs = { + "activation_granularity": activation_granularity, + "activation_dtype": activation_dtype, + } - quantized_weight = to_linear_activation_quantized( - quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs - ) - return quantized_weight + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs + ) - return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def float8_static_activation_float8_weight( - scale: torch.Tensor, - activation_dtype: torch.dtype = torch.float8_e4m3fn, - weight_dtype: torch.dtype = torch.float8_e4m3fn, - granularity: Optional[ - Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]] - ] = None, - mm_config: 

-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to

     Args:
         scale (torch.Tensor): The scale tensor for activation quantization.
@@ -1249,47 +1276,74 @@ def float8_static_activation_float8_weight(
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
+
+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for bc
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
     assert (
         is_sm_at_least_89() or is_MI300()
     ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight

     activation_granularity, weight_granularity = _normalize_granularity(granularity)
     assert isinstance(
         activation_granularity, PerTensor
     ), "Static quantization only supports PerTensor granularity"

-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }

-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
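
Finally, a static-activation sketch (illustrative, assuming CUDA capability 8.9+; the scale value is arbitrary). scale is the only required field, and as the test change above notes, it must live on the same device as the module.

    import torch
    from torchao.quantization import quantize_, Float8StaticActivationFloat8WeightConfig

    m = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
    # static quantization only supports PerTensor activation granularity
    config = Float8StaticActivationFloat8WeightConfig(
        scale=torch.tensor([1.0], device="cuda"),
    )
    quantize_(m, config)
    y = m(torch.randn(32, 128, device="cuda", dtype=torch.bfloat16))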