diff --git a/tutorials/calibration_flow/awq_like.py b/tutorials/calibration_flow/awq_like.py index 5742b9b328..c047b8531e 100644 --- a/tutorials/calibration_flow/awq_like.py +++ b/tutorials/calibration_flow/awq_like.py @@ -8,11 +8,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -33,6 +35,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error @@ -83,61 +88,72 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) +@dataclass +class ApplyAWQConfig(AOBaseConfig): + target_dtype: torch.dtype + + # converting observed linear module to linear module with quantzied weights (and quantized activations) # with tensor subclasses -def apply_awq(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_awq_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=None), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias - # activation quantization - # pretend this to be the equalization scale, in reality the `act_obs` should - # be an observer that can caluclate equalization scale - equalization_scale, _ = observed_linear.act_obs.calculate_qparams() - equalization_scale = torch.ones_like(equalization_scale) - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight * equalization_scale), requires_grad=False - ) +@register_quantize_module_handler(ApplyAWQConfig) +def _apply_awq_transform( + module: torch.nn.Module, + config: ApplyAWQConfig, +): + target_dtype = config.target_dtype + observed_linear = module - linear.weight = torch.nn.Parameter( - to_weight_tensor_with_linear_activation_scale_metadata( - linear.weight, equalization_scale - ), - requires_grad=False, - ) + # target_dtype = torch.uint8 + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) + if target_dtype == torch.uint8: + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=None), + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + 
dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias + + # activation quantization + # pretend this to be the equalization scale, in reality the `act_obs` should + # be an observer that can caluclate equalization scale + equalization_scale, _ = observed_linear.act_obs.calculate_qparams() + equalization_scale = torch.ones_like(equalization_scale) - return linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight * equalization_scale), requires_grad=False + ) + + linear.weight = torch.nn.Parameter( + to_weight_tensor_with_linear_activation_scale_metadata( + linear.weight, equalization_scale + ), + requires_grad=False, + ) - return _apply_awq_to_linear + return linear ######## Test ########## @@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_awq(target_dtype), is_observed_linear) + quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert compute_error(before_quant, after_quant) > 25 diff --git a/tutorials/calibration_flow/gptq_like.py b/tutorials/calibration_flow/gptq_like.py index 93c7e3c4ab..e4f28faf6f 100644 --- a/tutorials/calibration_flow/gptq_like.py +++ b/tutorials/calibration_flow/gptq_like.py @@ -33,6 +33,7 @@ import torch from torch.utils._pytree import tree_flatten, tree_unflatten +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( to_affine_quantized_intx, to_affine_quantized_intx_static, @@ -47,6 +48,9 @@ to_linear_activation_quantized, ) from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error torch.manual_seed(0) @@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module): ) -# using a function to align with the API in quant_api -def apply_activation_static_weight_quant(): - def _apply_activation_static_weight_quant(observed_linear): - target_dtype = torch.uint8 - - # we can quantize the weight here as well +class ApplyActivationStaticWeightQuantConfig(AOBaseConfig): + pass - # activation quantization - act_scale, act_zero_point = ( - observed_linear.input_scale, - observed_linear.input_zp, - ) - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype - ) - # for demo purpose only, we quantize the weight here - weight = observed_linear.weight - weight = to_affine_quantized_intx( - weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 - ) - observed_linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(weight, input_quant_func), - requires_grad=False, - ) - del observed_linear.input_scale - del observed_linear.input_zp - return observed_linear +# using a function to align with the API in quant_api +@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig) +def _apply_activation_static_weight_quant_transform( + module: torch.nn.Module, + config: ApplyActivationStaticWeightQuantConfig, +): + observed_linear = module + target_dtype = torch.uint8 + + # we can quantize the weight here as well + + # activation quantization + act_scale, act_zero_point = ( + observed_linear.input_scale, + 
observed_linear.input_zp, + ) + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + # for demo purpose only, we quantize the weight here + weight = observed_linear.weight + weight = to_affine_quantized_intx( + weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8 + ) + observed_linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(weight, input_quant_func), + requires_grad=False, + ) - return _apply_activation_static_weight_quant + del observed_linear.input_scale + del observed_linear.input_zp + return observed_linear example_inputs = (torch.randn(32, 64),) @@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear): # just quantizing activation since we only observed quantization, this could be extended to support # quantizing weight as well -quantize_(m, apply_activation_static_weight_quant(), _is_linear) +quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear) for l in m.modules(): if isinstance(l, torch.nn.Linear): assert isinstance(l.weight, LinearActivationQuantizedTensor) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index fd24a71189..1ebce411d3 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -3,11 +3,13 @@ """ import copy +from dataclasses import dataclass import torch import torch.nn.functional as F from torch import Tensor +from torchao.core.config import AOBaseConfig from torchao.dtypes import ( Float8Layout, to_affine_quantized_floatx_static, @@ -26,6 +28,9 @@ from torchao.quantization.quant_primitives import ( MappingType, ) +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) from torchao.quantization.utils import compute_error from torchao.utils import is_sm_at_least_90 @@ -77,66 +82,74 @@ def replacement_fn(m): _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) -# converting observed linear module to linear module with quantzied weights (and quantized activations) -# with tensor subclasses -def apply_static_quant(target_dtype: torch.dtype): - # target_dtype = torch.uint8 - def _apply_static_quant_to_linear(observed_linear): - # weight quantization - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() - - def weight_quant_func(weight): - block_size = (1, weight.shape[1]) - if target_dtype == torch.uint8: - return to_affine_quantized_intx_static( - weight, weight_scale, weight_zero_point, block_size, target_dtype - ) - elif target_dtype == torch.float8_e4m3fn: - mm_config = Float8MMConfig(use_fast_accum=True) - return to_affine_quantized_floatx_static( - weight, - weight_scale, - block_size, - target_dtype, - Float8Layout(mm_config=mm_config), - ) - else: - raise ValueError(f"Unsupported target dtype {target_dtype}") - - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - False, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.weight = observed_linear.weight - linear.bias = observed_linear.bias +@dataclass +class ApplyStaticQuantConfig(AOBaseConfig): + target_dtype: torch.dtype - linear.weight = torch.nn.Parameter( - weight_quant_func(linear.weight), requires_grad=False - ) - # activation quantization - act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() +# converting observed linear module to linear module with quantzied weights (and quantized 
activations) +# with tensor subclasses +@register_quantize_module_handler(ApplyStaticQuantConfig) +def _apply_static_quant_transform( + module: torch.nn.Module, + config: ApplyStaticQuantConfig, +): + target_dtype = config.target_dtype + observed_linear = module + + # weight quantization + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() + + def weight_quant_func(weight): + block_size = (1, weight.shape[1]) if target_dtype == torch.uint8: - input_quant_func = lambda x: to_affine_quantized_intx_static( - x, act_scale, act_zero_point, x.shape, target_dtype + return to_affine_quantized_intx_static( + weight, weight_scale, weight_zero_point, block_size, target_dtype ) elif target_dtype == torch.float8_e4m3fn: - input_quant_func = lambda x: to_affine_quantized_floatx_static( - x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + mm_config = Float8MMConfig(use_fast_accum=True) + return to_affine_quantized_floatx_static( + weight, + weight_scale, + block_size, + target_dtype, + Float8Layout(mm_config=mm_config), ) else: raise ValueError(f"Unsupported target dtype {target_dtype}") - linear.weight = torch.nn.Parameter( - to_linear_activation_quantized(linear.weight, input_quant_func), - requires_grad=False, - ) - return linear + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + False, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) + linear.weight = observed_linear.weight + linear.bias = observed_linear.bias - return _apply_static_quant_to_linear + linear.weight = torch.nn.Parameter( + weight_quant_func(linear.weight), requires_grad=False + ) + + # activation quantization + act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams() + if target_dtype == torch.uint8: + input_quant_func = lambda x: to_affine_quantized_intx_static( + x, act_scale, act_zero_point, x.shape, target_dtype + ) + elif target_dtype == torch.float8_e4m3fn: + input_quant_func = lambda x: to_affine_quantized_floatx_static( + x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None) + ) + else: + raise ValueError(f"Unsupported target dtype {target_dtype}") + linear.weight = torch.nn.Parameter( + to_linear_activation_quantized(linear.weight, input_quant_func), + requires_grad=False, + ) + + return linear # alternative for converting observed linear module to quantized linear module @@ -210,11 +223,17 @@ def from_observed(cls, observed_linear, target_dtype): return quantized_linear -def apply_static_quant2(target_dtype: torch.dtype): - def _apply_static_quant2(observed_linear): - return QuantizedLinear.from_observed(observed_linear, target_dtype) +@dataclass +class ApplyStaticQuantConfig2(AOBaseConfig): + target_dtype: torch.dtype + - return _apply_static_quant2 +@register_quantize_module_handler(ApplyStaticQuantConfig2) +def apply_static_quant( + module: torch.nn.Module, + config: ApplyStaticQuantConfig2, +): + return QuantizedLinear.from_observed(module, config.target_dtype) class ToyLinearModel(torch.nn.Module): @@ -281,14 +300,14 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType): # quantized linear represented as an nn.Linear with modified tensor subclass weights # for both activation and weight quantization - quantize_(m, apply_static_quant(target_dtype), is_observed_linear) + quantize_(m, ApplyStaticQuantConfig(target_dtype), is_observed_linear) print("quantized model (applying tensor subclass to weight):", m) after_quant = m(*example_inputs) assert 
compute_error(before_quant, after_quant) > 25 print("test passed") # quantized linear as a standalone module - quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear) + quantize_(m2, ApplyStaticQuantConfig2(target_dtype), is_observed_linear) print("quantized model (quantized module):", m2) after_quant = m2(*example_inputs) assert compute_error(before_quant, after_quant) > 25
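All three tutorials above migrate to the same pattern: a dataclass config subclassing AOBaseConfig, plus a module handler registered with register_quantize_module_handler, which quantize_ then dispatches on when it receives a config instance. Below is a minimal standalone sketch of that pattern, kept separate from the tutorial code; the CastWeightConfig name, the toy model, and the weight-casting handler body are hypothetical illustrations of the registration/dispatch flow, not torchao APIs or the tutorials' actual quantization logic.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


@dataclass
class CastWeightConfig(AOBaseConfig):
    # hypothetical toy config: just cast linear weights to this dtype
    dtype: torch.dtype = torch.bfloat16


@register_quantize_module_handler(CastWeightConfig)
def _cast_weight_transform(module: torch.nn.Module, config: CastWeightConfig):
    # the handler receives the matched module and the config instance, and
    # returns the module to install in its place -- the same contract the
    # converted tutorials follow
    module.weight = torch.nn.Parameter(
        module.weight.to(config.dtype), requires_grad=False
    )
    return module


m = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.Linear(32, 16))
# the third argument is the filter function selecting which modules the
# handler is applied to, mirroring `is_observed_linear` / `_is_linear` above
quantize_(
    m,
    CastWeightConfig(torch.bfloat16),
    lambda mod, fqn: isinstance(mod, torch.nn.Linear),
)
print(m[0].weight.dtype)  # torch.bfloat16

The key difference from the old flow is that quantize_ no longer receives a closure built by a factory function (apply_awq(target_dtype), apply_static_quant2(target_dtype), ...); it receives a plain config object, and the transform logic lives in the registered handler keyed on that config's class.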