|
8 | 8 | """
|
9 | 9 |
|
10 | 10 | import copy
|
| 11 | +from dataclasses import dataclass |
11 | 12 |
|
12 | 13 | import torch
|
13 | 14 | import torch.nn.functional as F
|
14 | 15 | from torch import Tensor
|
15 | 16 |
|
| 17 | +from torchao.core.config import AOBaseConfig |
16 | 18 | from torchao.dtypes import (
|
17 | 19 | Float8Layout,
|
18 | 20 | to_affine_quantized_floatx_static,
|
|
33 | 35 | from torchao.quantization.quant_primitives import (
|
34 | 36 | MappingType,
|
35 | 37 | )
|
| 38 | +from torchao.quantization.transform_module import ( |
| 39 | + register_quantize_module_handler, |
| 40 | +) |
36 | 41 | from torchao.quantization.utils import compute_error
|
37 | 42 |
|
38 | 43 |
|
@@ -83,61 +88,72 @@ def replacement_fn(m):
|
83 | 88 | _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)
|
84 | 89 |
|
85 | 90 |
|
| 91 | +@dataclass |
| 92 | +class ApplyAWQConfig(AOBaseConfig): |
| 93 | + target_dtype: torch.dtype |
| 94 | + |
| 95 | + |
86 | 96 | # converting observed linear module to linear module with quantized weights (and quantized activations)
|
87 | 97 | # with tensor subclasses
|
88 |
| -def apply_awq(target_dtype: torch.dtype): |
89 |
| - # target_dtype = torch.uint8 |
90 |
| - def _apply_awq_to_linear(observed_linear): |
91 |
| - # weight quantization |
92 |
| - weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() |
93 |
| - |
94 |
| - def weight_quant_func(weight): |
95 |
| - block_size = (1, weight.shape[1]) |
96 |
| - if target_dtype == torch.uint8: |
97 |
| - return to_affine_quantized_intx_static( |
98 |
| - weight, weight_scale, weight_zero_point, block_size, target_dtype |
99 |
| - ) |
100 |
| - elif target_dtype == torch.float8_e4m3fn: |
101 |
| - return to_affine_quantized_floatx_static( |
102 |
| - weight, |
103 |
| - weight_scale, |
104 |
| - block_size, |
105 |
| - target_dtype, |
106 |
| - Float8Layout(mm_config=None), |
107 |
| - ) |
108 |
| - else: |
109 |
| - raise ValueError(f"Unsupported target dtype {target_dtype}") |
110 |
| - |
111 |
| - linear = torch.nn.Linear( |
112 |
| - observed_linear.in_features, |
113 |
| - observed_linear.out_features, |
114 |
| - False, |
115 |
| - device=observed_linear.weight.device, |
116 |
| - dtype=observed_linear.weight.dtype, |
117 |
| - ) |
118 |
| - linear.weight = observed_linear.weight |
119 |
| - linear.bias = observed_linear.bias |
120 | 98 |
|
121 |
| - # activation quantization |
122 |
| - # pretend this to be the equalization scale, in reality the `act_obs` should |
123 |
| - # be an observer that can caluclate equalization scale |
124 |
| - equalization_scale, _ = observed_linear.act_obs.calculate_qparams() |
125 |
| - equalization_scale = torch.ones_like(equalization_scale) |
126 | 99 |
|
127 |
| - linear.weight = torch.nn.Parameter( |
128 |
| - weight_quant_func(linear.weight * equalization_scale), requires_grad=False |
129 |
| - ) |
| 100 | +@register_quantize_module_handler(ApplyAWQConfig) |
| 101 | +def _apply_awq_transform( |
| 102 | + module: torch.nn.Module, |
| 103 | + config: ApplyAWQConfig, |
| 104 | +): |
| 105 | + target_dtype = config.target_dtype |
| 106 | + observed_linear = module |
130 | 107 |
|
131 |
| - linear.weight = torch.nn.Parameter( |
132 |
| - to_weight_tensor_with_linear_activation_scale_metadata( |
133 |
| - linear.weight, equalization_scale |
134 |
| - ), |
135 |
| - requires_grad=False, |
136 |
| - ) |
| 108 | + # target_dtype = torch.uint8 |
| 109 | + # weight quantization |
| 110 | + weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams() |
| 111 | + |
| 112 | + def weight_quant_func(weight): |
| 113 | + block_size = (1, weight.shape[1]) |
| 114 | + if target_dtype == torch.uint8: |
| 115 | + return to_affine_quantized_intx_static( |
| 116 | + weight, weight_scale, weight_zero_point, block_size, target_dtype |
| 117 | + ) |
| 118 | + elif target_dtype == torch.float8_e4m3fn: |
| 119 | + return to_affine_quantized_floatx_static( |
| 120 | + weight, |
| 121 | + weight_scale, |
| 122 | + block_size, |
| 123 | + target_dtype, |
| 124 | + Float8Layout(mm_config=None), |
| 125 | + ) |
| 126 | + else: |
| 127 | + raise ValueError(f"Unsupported target dtype {target_dtype}") |
| 128 | + |
| 129 | + linear = torch.nn.Linear( |
| 130 | + observed_linear.in_features, |
| 131 | + observed_linear.out_features, |
| 132 | + False, |
| 133 | + device=observed_linear.weight.device, |
| 134 | + dtype=observed_linear.weight.dtype, |
| 135 | + ) |
| 136 | + linear.weight = observed_linear.weight |
| 137 | + linear.bias = observed_linear.bias |
| 138 | + |
| 139 | + # activation quantization |
| 140 | + # pretend this to be the equalization scale, in reality the `act_obs` should |
| 141 | + # be an observer that can calculate equalization scale |
| 142 | + equalization_scale, _ = observed_linear.act_obs.calculate_qparams() |
| 143 | + equalization_scale = torch.ones_like(equalization_scale) |
137 | 144 |
|
138 |
| - return linear |
| 145 | + linear.weight = torch.nn.Parameter( |
| 146 | + weight_quant_func(linear.weight * equalization_scale), requires_grad=False |
| 147 | + ) |
| 148 | + |
| 149 | + linear.weight = torch.nn.Parameter( |
| 150 | + to_weight_tensor_with_linear_activation_scale_metadata( |
| 151 | + linear.weight, equalization_scale |
| 152 | + ), |
| 153 | + requires_grad=False, |
| 154 | + ) |
139 | 155 |
|
140 |
| - return _apply_awq_to_linear |
| 156 | + return linear |
141 | 157 |
|
142 | 158 |
|
143 | 159 | ######## Test ##########
|
@@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
|
201 | 217 |
|
202 | 218 | # quantized linear represented as an nn.Linear with modified tensor subclass weights
|
203 | 219 | # for both activation and weight quantization
|
204 |
| - quantize_(m, apply_awq(target_dtype), is_observed_linear) |
| 220 | + quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear) |
205 | 221 | print("quantized model (applying tensor subclass to weight):", m)
|
206 | 222 | after_quant = m(*example_inputs)
|
207 | 223 | assert compute_error(before_quant, after_quant) > 25
|
|
0 commit comments