migrate static quant tutorials to direct configuration #1710

Merged Feb 14, 2025 (50 commits)

Commits
All 50 commits are titled "Update" and authored by vkuzo:

Jan 22, 2025: 24114ce, 5b9d876, 1cea42f, 138883b, ba045ea, 94d9426
Jan 23, 2025: b589ce7
Feb 5, 2025: aaba2d8, 26850da
Feb 10, 2025: 7caecb1, d42a590
Feb 11, 2025: 5702ea0, 0542402, 146ac3b, 5f75897, 1c9c39f, 1ff1f6e, bb253ef, c2ed2da, 698989b, 6184530, 397002e, 5514a99, fac3263, 1e15950, e9c03e0, f5b7d87, 6684b39
Feb 12, 2025: 4dcb349
Feb 13, 2025: d63e657, 36c2096, ca7531d, b55b1bb, 3aaf5a0, 3fd4cfc
Feb 14, 2025: ac7e5da, 1e152e3, 0be10ae, 2f0d4e3, e397c47, 9eebc4f, 81dcff8, f44befc, e534d64, 54d3c31, 7688b35, e776f11, 03fb862, 0c09446, 1979394
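Across all three tutorials the change follows the same pattern: the function that returned a per-module transform (e.g. apply_awq(target_dtype)) is replaced by an AOBaseConfig dataclass plus a handler registered with register_quantize_module_handler, and call sites pass a config instance to quantize_. The sketch below illustrates that pattern only; MyStaticQuantConfig and _my_transform are hypothetical names, while the torchao imports match the ones added in the diffs that follow.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


@dataclass
class MyStaticQuantConfig(AOBaseConfig):
    # configuration is plain data instead of arguments captured in a closure
    target_dtype: torch.dtype


@register_quantize_module_handler(MyStaticQuantConfig)
def _my_transform(module: torch.nn.Module, config: MyStaticQuantConfig):
    # quantize_ routes every module that passes the filter to this handler;
    # the handler returns the (possibly replaced or rewrapped) module
    return module


model = torch.nn.Sequential(torch.nn.Linear(64, 128))
# old style: quantize_(model, apply_static_quant(torch.uint8), filter_fn)
# new style:
quantize_(model, MyStaticQuantConfig(torch.uint8), lambda m, fqn: isinstance(m, torch.nn.Linear))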
114 changes: 65 additions & 49 deletions tutorials/calibration_flow/awq_like.py
@@ -8,11 +8,13 @@
"""

import copy
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
Float8Layout,
to_affine_quantized_floatx_static,
@@ -33,6 +35,9 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error


@@ -83,61 +88,72 @@ def replacement_fn(m):
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


@dataclass
class ApplyAWQConfig(AOBaseConfig):
target_dtype: torch.dtype


# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_awq(target_dtype: torch.dtype):
# target_dtype = torch.uint8
def _apply_awq_to_linear(observed_linear):
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=None),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

# activation quantization
# pretend this is the equalization scale; in reality the `act_obs` should
# be an observer that can calculate the equalization scale
equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
equalization_scale = torch.ones_like(equalization_scale)

linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight * equalization_scale), requires_grad=False
)
@register_quantize_module_handler(ApplyAWQConfig)
def _apply_awq_transform(
module: torch.nn.Module,
config: ApplyAWQConfig,
):
target_dtype = config.target_dtype
observed_linear = module

linear.weight = torch.nn.Parameter(
to_weight_tensor_with_linear_activation_scale_metadata(
linear.weight, equalization_scale
),
requires_grad=False,
)
# target_dtype = torch.uint8
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=None),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

# activation quantization
# pretend this is the equalization scale; in reality the `act_obs` should
# be an observer that can calculate the equalization scale
equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
equalization_scale = torch.ones_like(equalization_scale)

return linear
linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight * equalization_scale), requires_grad=False
)

linear.weight = torch.nn.Parameter(
to_weight_tensor_with_linear_activation_scale_metadata(
linear.weight, equalization_scale
),
requires_grad=False,
)

return _apply_awq_to_linear
return linear


######## Test ##########
@@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_awq(target_dtype), is_observed_linear)
quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
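The comments in the transform above note that the equalization scale is faked (set to ones) and that in a real AWQ flow act_obs would compute it from calibration data. As a rough illustration only, here is a minimal sketch of such an observer; the class name, the buffer, and the abs-max-based scale are assumptions for illustration and do not reproduce AWQ's actual scale search.

import torch


class EqualizationScaleObserver(torch.nn.Module):
    """Tracks per-input-channel activation magnitudes during calibration (hypothetical)."""

    def __init__(self, in_features: int):
        super().__init__()
        self.register_buffer("act_absmax", torch.zeros(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # record the per-channel absolute max seen so far; pass the input through
        flat = x.detach().reshape(-1, x.shape[-1])
        self.act_absmax = torch.maximum(self.act_absmax, flat.abs().amax(dim=0))
        return x

    def calculate_equalization_scale(self) -> torch.Tensor:
        # a simple smoothing-style scale; AWQ proper searches over candidate scales
        return self.act_absmax.clamp(min=1e-5).sqrt()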
66 changes: 38 additions & 28 deletions tutorials/calibration_flow/gptq_like.py
@@ -33,6 +33,7 @@
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
to_affine_quantized_intx,
to_affine_quantized_intx_static,
@@ -47,6 +48,9 @@
to_linear_activation_quantized,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error

torch.manual_seed(0)
@@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module):
)


# using a function to align with the API in quant_api
def apply_activation_static_weight_quant():
def _apply_activation_static_weight_quant(observed_linear):
target_dtype = torch.uint8

# we can quantize the weight here as well
class ApplyActivationStaticWeightQuantConfig(AOBaseConfig):
pass

# activation quantization
act_scale, act_zero_point = (
observed_linear.input_scale,
observed_linear.input_zp,
)
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
)
# for demo purposes only, we quantize the weight here
weight = observed_linear.weight
weight = to_affine_quantized_intx(
weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
)
observed_linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(weight, input_quant_func),
requires_grad=False,
)

del observed_linear.input_scale
del observed_linear.input_zp
return observed_linear
# using a function to align with the API in quant_api
@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig)
def _apply_activation_static_weight_quant_transform(
module: torch.nn.Module,
config: ApplyActivationStaticWeightQuantConfig,
):
observed_linear = module
target_dtype = torch.uint8

# we can quantize the weight here as well

# activation quantization
act_scale, act_zero_point = (
observed_linear.input_scale,
observed_linear.input_zp,
)
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
)
# for demo purposes only, we quantize the weight here
weight = observed_linear.weight
weight = to_affine_quantized_intx(
weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
)
observed_linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(weight, input_quant_func),
requires_grad=False,
)

return _apply_activation_static_weight_quant
del observed_linear.input_scale
del observed_linear.input_zp
return observed_linear


example_inputs = (torch.randn(32, 64),)
@@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear):

# just quantizing activations since we only observed them; this could be extended to support
# quantizing weight as well
quantize_(m, apply_activation_static_weight_quant(), _is_linear)
quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear)
for l in m.modules():
if isinstance(l, torch.nn.Linear):
assert isinstance(l.weight, LinearActivationQuantizedTensor)
131 changes: 75 additions & 56 deletions tutorials/calibration_flow/static_quant.py
@@ -3,11 +3,13 @@
"""

import copy
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
Float8Layout,
to_affine_quantized_floatx_static,
@@ -26,6 +28,9 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error
from torchao.utils import is_sm_at_least_90

@@ -77,66 +82,74 @@ def replacement_fn(m):
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(target_dtype: torch.dtype):
# target_dtype = torch.uint8
def _apply_static_quant_to_linear(observed_linear):
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
mm_config = Float8MMConfig(use_fast_accum=True)
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=mm_config),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias
@dataclass
class ApplyStaticQuantConfig(AOBaseConfig):
target_dtype: torch.dtype

linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight), requires_grad=False
)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
@register_quantize_module_handler(ApplyStaticQuantConfig)
def _apply_static_quant_transform(
module: torch.nn.Module,
config: ApplyStaticQuantConfig,
):
target_dtype = config.target_dtype
observed_linear = module

# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
input_quant_func = lambda x: to_affine_quantized_floatx_static(
x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)
mm_config = Float8MMConfig(use_fast_accum=True)
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=mm_config),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")
linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(linear.weight, input_quant_func),
requires_grad=False,
)

return linear
linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

return _apply_static_quant_to_linear
linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight), requires_grad=False
)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
if target_dtype == torch.uint8:
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
input_quant_func = lambda x: to_affine_quantized_floatx_static(
x, act_scale, x.shape, target_dtype, Float8Layout(mm_config=None)
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")
linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(linear.weight, input_quant_func),
requires_grad=False,
)

return linear


# alternative for converting observed linear module to quantized linear module
@@ -210,11 +223,17 @@ def from_observed(cls, observed_linear, target_dtype):
return quantized_linear


def apply_static_quant2(target_dtype: torch.dtype):
def _apply_static_quant2(observed_linear):
return QuantizedLinear.from_observed(observed_linear, target_dtype)
@dataclass
class ApplyStaticQuantConfig2(AOBaseConfig):
target_dtype: torch.dtype


return _apply_static_quant2
@register_quantize_module_handler(ApplyStaticQuantConfig2)
def apply_static_quant(
module: torch.nn.Module,
config: ApplyStaticQuantConfig2,
):
return QuantizedLinear.from_observed(module, config.target_dtype)


class ToyLinearModel(torch.nn.Module):
@@ -281,14 +300,14 @@ def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant(target_dtype), is_observed_linear)
quantize_(m, ApplyStaticQuantConfig(target_dtype), is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear)
quantize_(m2, ApplyStaticQuantConfig2(target_dtype), is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
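Beyond the SQNR check, a quick way to confirm which representation each flavor produced is to inspect the modules after quantize_. This is a sketch under assumptions: it reuses m, m2, and QuantizedLinear from this tutorial, and the import path for LinearActivationQuantizedTensor is taken from gptq_like.py's usage rather than stated in this diff.

from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,  # import path assumed
)

# flavor 1 (ApplyStaticQuantConfig): modules are still nn.Linear, but the
# weight tensor is wrapped so activations are statically quantized on the fly
for mod in m.modules():
    if isinstance(mod, torch.nn.Linear):
        assert isinstance(mod.weight, LinearActivationQuantizedTensor)

# flavor 2 (ApplyStaticQuantConfig2): the observed linears were swapped out
# for the standalone QuantizedLinear module defined earlier in this file
assert any(isinstance(mod, QuantizedLinear) for mod in m2.modules())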