import torch

from torchao.dtypes import (
    TensorCoreTiledLayout,
    to_affine_quantized_intx,
)
from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
    ZeroPointDomain,
)

from .core import (
    AWQObservedLinear,
    AWQObserver,
)

assert (
    len(_DTYPE_TO_BIT_WIDTH) > 0
), "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"


def insert_awq_observer_(
    model: torch.nn.Module,
    n_validation_examples: int,
    validation_sequence_len: int,
    quant_dtype: torch.dtype = torch.uint4,
    scale_search_space_size: int = 20,
    group_size: int = 128,
):
    """
    Inserts AWQObserver into Linear layers of a given model.

    Args:
        model: The model to be mutated.
        n_validation_examples: Number of examples used to validate scale options
        validation_sequence_len: Number of tokens in each validation example
        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to work but can be used with torch.uint1 -> torch.uint8
        scale_search_space_size: The number of scale options to try
        group_size: Quantization granularity. Use -1 for channel wise quantization
    """
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
    assert (
        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
    # AQT config
    mapping_type = MappingType.ASYMMETRIC
    quantization_granularity = PerGroup(group_size)
    quant_min = 0
    quant_max = (
        255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
    )
    eps = torch.finfo(torch.float32).eps
    preserve_zero = True
    zero_point_dtype = torch.int64
    zero_point_domain = ZeroPointDomain.INT

    def replace_with_observer(layer):
        # creates an observer and replaces the Linear layer with an AWQObservedLinear layer
        observer = AWQObserver(
            layer.weight,
            layer.bias,
            quantization_granularity,
            mapping_type,
            quant_dtype,
            n_validation_examples,
            validation_sequence_len,
            scale_search_space_size,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            zero_point_dtype=zero_point_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=eps,
        )
        return AWQObservedLinear.from_float(layer, observer)

    _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
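# Typical flow (sketch): call insert_awq_observer_ on a model, run calibration
# batches through it so each AWQObserver records activation statistics, then
# pass awq_uintx() (defined below) to quantize_() to materialize the quantized
# weights. A fuller sketch appears at the end of this file.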


def _observed_linear_subclass_inserter(constructor):
    """
    Replaces unquantized AWQObservedLinear instances with quantized linear instances.

    Args:
        constructor: the function which applies quantization to the AWQObservedLinear layer
    """

    def insert_subclass(observed_linear):
        # creates the new linear layer using constructor
        linear = torch.nn.Linear(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.bias is not None,
            device=observed_linear.weight.device,
            dtype=observed_linear.weight.dtype,
        )
        linear.weight = torch.nn.Parameter(
            constructor(observed_linear), requires_grad=False
        )
        linear.bias = observed_linear.bias
        return linear

    return insert_subclass


def awq_uintx(
    quant_dtype: torch.dtype = torch.uint4,
    group_size: int = 64,
    use_hqq: bool = False,
):
    """
    Quantizes linear layers when passed into quantize_()

    Args:
        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to work but can be used with torch.uint1 -> torch.uint8
        group_size: Quantization granularity. Use -1 for channel wise quantization
        use_hqq: whether to use HQQ (Half-Quadratic Quantization) when computing the quantized weight values
    """
    assert (
        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"

    def weight_quant_func(observed_linear):
        equalization_scale = observed_linear.act_obs.calculate_qparams()
        # AQT config
        # uint4 routes to the tinygemm TensorCoreTiledLayout; all other bit
        # widths fall back to the generic Uintx layout
        if quant_dtype == torch.uint4:
            target_dtype = torch.int32
            eps = 1e-6
            preserve_zero = False
            zero_point_dtype = torch.bfloat16
            zero_point_domain = ZeroPointDomain.FLOAT
            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
        else:
            target_dtype = torch.uint8
            eps = torch.finfo(torch.float32).eps
            preserve_zero = True
            zero_point_dtype = torch.int64
            zero_point_domain = ZeroPointDomain.INT
            _layout = UintxLayout(quant_dtype)

        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
        qw = to_affine_quantized_intx(
            observed_linear.weight * equalization_scale,
            mapping_type,
            block_size,
            target_dtype,
            quant_min,
            quant_max,
            eps,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            _layout=_layout,
            use_hqq=use_hqq,
        )

        return to_weight_tensor_with_linear_activation_scale_metadata(
            qw, equalization_scale
        )

    return _observed_linear_subclass_inserter(weight_quant_func)
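

# A minimal end-to-end sketch of the calibration flow described above. The toy
# model, batch shapes, and example counts are illustrative assumptions, not
# part of this module's API; quantize_ is torchao's standard entry point. The
# uint4 path targets the tinygemm int4 kernels, hence the bfloat16 model.
if __name__ == "__main__":
    from torchao.quantization import quantize_

    model = torch.nn.Sequential(
        torch.nn.Linear(128, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128),
    ).to(torch.bfloat16)

    # Step 1: swap each Linear for an AWQObservedLinear wrapper.
    insert_awq_observer_(
        model,
        n_validation_examples=2,
        validation_sequence_len=16,
        quant_dtype=torch.uint4,
        group_size=64,
    )

    # Step 2: run calibration batches so the observers record the activation
    # statistics used for the equalization-scale search.
    for _ in range(2):
        model(torch.randn(1, 16, 128, dtype=torch.bfloat16))

    # Step 3: materialize quantized weights, replacing only observed layers.
    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
    quantize_(
        model,
        awq_uintx(quant_dtype=torch.uint4, group_size=64),
        is_observed_linear,
    )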