
Commit 3fa8e44

migrate static quant tutorials to direct configuration (#1710)
1 parent 17b9ce3 commit 3fa8e44

3 files changed: +178, -133 lines
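All of the changed tutorials move from factory functions that return a transform callable to plain config objects handled by registered transforms. As a rough orientation before the per-file diffs, the new pattern looks like the sketch below; `MyStaticQuantConfig` and `_my_transform` are hypothetical placeholder names, while `AOBaseConfig`, `register_quantize_module_handler`, and `quantize_` are the APIs used in the changed files.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyStaticQuantConfig(AOBaseConfig):
    # configuration is now plain data instead of an argument captured by a closure
    target_dtype: torch.dtype


@register_quantize_module_handler(MyStaticQuantConfig)
def _my_transform(module: torch.nn.Module, config: MyStaticQuantConfig):
    # quantize_ dispatches here for every module matched by the filter;
    # the handler returns the (possibly replaced) module
    return module


# call sites now pass the config instance directly, e.g.:
# quantize_(model, MyStaticQuantConfig(torch.uint8), filter_fn)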

tutorials/calibration_flow/awq_like.py (+65, -49)
@@ -8,11 +8,13 @@
 """
 
 import copy
+from dataclasses import dataclass
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     Float8Layout,
     to_affine_quantized_floatx_static,
@@ -33,6 +35,9 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import compute_error
 
 
@@ -83,61 +88,72 @@ def replacement_fn(m):
     _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)
 
 
+@dataclass
+class ApplyAWQConfig(AOBaseConfig):
+    target_dtype: torch.dtype
+
+
 # converting observed linear module to linear module with quantzied weights (and quantized activations)
 # with tensor subclasses
-def apply_awq(target_dtype: torch.dtype):
-    # target_dtype = torch.uint8
-    def _apply_awq_to_linear(observed_linear):
-        # weight quantization
-        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
-
-        def weight_quant_func(weight):
-            block_size = (1, weight.shape[1])
-            if target_dtype == torch.uint8:
-                return to_affine_quantized_intx_static(
-                    weight, weight_scale, weight_zero_point, block_size, target_dtype
-                )
-            elif target_dtype == torch.float8_e4m3fn:
-                return to_affine_quantized_floatx_static(
-                    weight,
-                    weight_scale,
-                    block_size,
-                    target_dtype,
-                    Float8Layout(mm_config=None),
-                )
-            else:
-                raise ValueError(f"Unsupported target dtype {target_dtype}")
-
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            False,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = observed_linear.weight
-        linear.bias = observed_linear.bias
 
-        # activation quantization
-        # pretend this to be the equalization scale, in reality the `act_obs` should
-        # be an observer that can caluclate equalization scale
-        equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
-        equalization_scale = torch.ones_like(equalization_scale)
 
-        linear.weight = torch.nn.Parameter(
-            weight_quant_func(linear.weight * equalization_scale), requires_grad=False
-        )
+@register_quantize_module_handler(ApplyAWQConfig)
+def _apply_awq_transform(
+    module: torch.nn.Module,
+    config: ApplyAWQConfig,
+):
+    target_dtype = config.target_dtype
+    observed_linear = module
 
-        linear.weight = torch.nn.Parameter(
-            to_weight_tensor_with_linear_activation_scale_metadata(
-                linear.weight, equalization_scale
-            ),
-            requires_grad=False,
-        )
+    # target_dtype = torch.uint8
+    # weight quantization
+    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+
+    def weight_quant_func(weight):
+        block_size = (1, weight.shape[1])
+        if target_dtype == torch.uint8:
+            return to_affine_quantized_intx_static(
+                weight, weight_scale, weight_zero_point, block_size, target_dtype
+            )
+        elif target_dtype == torch.float8_e4m3fn:
+            return to_affine_quantized_floatx_static(
+                weight,
+                weight_scale,
+                block_size,
+                target_dtype,
+                Float8Layout(mm_config=None),
+            )
+        else:
+            raise ValueError(f"Unsupported target dtype {target_dtype}")
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        False,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = observed_linear.weight
+    linear.bias = observed_linear.bias
+
+    # activation quantization
+    # pretend this to be the equalization scale, in reality the `act_obs` should
+    # be an observer that can caluclate equalization scale
+    equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
+    equalization_scale = torch.ones_like(equalization_scale)
 
-        return linear
+    linear.weight = torch.nn.Parameter(
+        weight_quant_func(linear.weight * equalization_scale), requires_grad=False
+    )
+
+    linear.weight = torch.nn.Parameter(
+        to_weight_tensor_with_linear_activation_scale_metadata(
+            linear.weight, equalization_scale
+        ),
+        requires_grad=False,
+    )
 
-    return _apply_awq_to_linear
+    return linear
 
 
 ######## Test ##########
@@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):
 
     # quantized linear represented as an nn.Linear with modified tensor subclass weights
     # for both activation and weight quantization
-    quantize_(m, apply_awq(target_dtype), is_observed_linear)
+    quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear)
     print("quantized model (applying tensor subclass to weight):", m)
    after_quant = m(*example_inputs)
     assert compute_error(before_quant, after_quant) > 25
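Net effect for the AWQ-style tutorial: the `apply_awq` factory is replaced by `ApplyAWQConfig` plus a registered handler, so the call site changes as shown in the test hunk above. Roughly, as a before/after comparison assembled from this diff (with `m` being the observed model from the tutorial):

# before this commit
quantize_(m, apply_awq(target_dtype), is_observed_linear)

# after this commit
quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear)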

tutorials/calibration_flow/gptq_like.py (+38, -28)
@@ -33,6 +33,7 @@
 import torch
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
@@ -47,6 +48,9 @@
     to_linear_activation_quantized,
 )
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import compute_error
 
 torch.manual_seed(0)
@@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module):
     )
 
 
-# using a function to align with the API in quant_api
-def apply_activation_static_weight_quant():
-    def _apply_activation_static_weight_quant(observed_linear):
-        target_dtype = torch.uint8
-
-        # we can quantize the weight here as well
+class ApplyActivationStaticWeightQuantConfig(AOBaseConfig):
+    pass
 
-        # activation quantization
-        act_scale, act_zero_point = (
-            observed_linear.input_scale,
-            observed_linear.input_zp,
-        )
-        input_quant_func = lambda x: to_affine_quantized_intx_static(
-            x, act_scale, act_zero_point, x.shape, target_dtype
-        )
-        # for demo purpose only, we quantize the weight here
-        weight = observed_linear.weight
-        weight = to_affine_quantized_intx(
-            weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
-        )
-        observed_linear.weight = torch.nn.Parameter(
-            to_linear_activation_quantized(weight, input_quant_func),
-            requires_grad=False,
-        )
 
-        del observed_linear.input_scale
-        del observed_linear.input_zp
-        return observed_linear
+# using a function to align with the API in quant_api
+@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig)
+def _apply_activation_static_weight_quant_transform(
+    module: torch.nn.Module,
+    config: ApplyActivationStaticWeightQuantConfig,
+):
+    observed_linear = module
+    target_dtype = torch.uint8
+
+    # we can quantize the weight here as well
+
+    # activation quantization
+    act_scale, act_zero_point = (
+        observed_linear.input_scale,
+        observed_linear.input_zp,
+    )
+    input_quant_func = lambda x: to_affine_quantized_intx_static(
+        x, act_scale, act_zero_point, x.shape, target_dtype
+    )
+    # for demo purpose only, we quantize the weight here
+    weight = observed_linear.weight
+    weight = to_affine_quantized_intx(
+        weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
+    )
+    observed_linear.weight = torch.nn.Parameter(
+        to_linear_activation_quantized(weight, input_quant_func),
+        requires_grad=False,
+    )
 
-    return _apply_activation_static_weight_quant
+    del observed_linear.input_scale
+    del observed_linear.input_zp
+    return observed_linear
 
 
 example_inputs = (torch.randn(32, 64),)
@@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear):
 
 # just quantizing activation since we only observed quantization, this could be extended to support
 # quantizing weight as well
-quantize_(m, apply_activation_static_weight_quant(), _is_linear)
+quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear)
 for l in m.modules():
     if isinstance(l, torch.nn.Linear):
         assert isinstance(l.weight, LinearActivationQuantizedTensor)
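Same migration for the GPTQ-style flow, except that the config carries no fields (the handler hard-codes `target_dtype = torch.uint8` internally), so the call site just passes an empty config instance; a before/after comparison assembled from this diff:

# before this commit
quantize_(m, apply_activation_static_weight_quant(), _is_linear)

# after this commit
quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear)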
