@@ -1,28 +1,37 @@
 import torch
-import torch.nn.functional as F
 
+from torchao.dtypes import (
+    TensorCoreTiledLayout,
+    to_affine_quantized_intx,
+)
+from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
+from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
-    _DTYPE_TO_QVALUE_BOUNDS,
 )
-from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
-from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
-from torchao.dtypes import (
-    to_affine_quantized_intx,
-    TensorCoreTiledLayout,
+
+from .core import (
+    AWQObservedLinear,
+    AWQObserver,
 )
-from .core import (
-    AWQObserver,
-    AWQObservedLinear,
-)
 
+assert (
+    len(_DTYPE_TO_BIT_WIDTH) > 0
+), "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"
 
-assert len(_DTYPE_TO_BIT_WIDTH) > 0, "Error importing low bit torch.uint dtypes. Please upgrade to torch 2.3+"
 
-def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, validation_sequence_len: int, quant_dtype: torch.dtype = torch.uint4, scale_search_space_size: int = 20, group_size: int = 128):
+def insert_awq_observer_(
+    model: torch.nn.Module,
+    n_validation_examples: int,
+    validation_sequence_len: int,
+    quant_dtype: torch.dtype = torch.uint4,
+    scale_search_space_size: int = 20,
+    group_size: int = 128,
+):
     """
     Inserts AWQObserver into Linear layers of a given model.
 
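The import-time assert in this hunk guards the low-bit `torch.uintN` dtypes behind `_DTYPE_TO_BIT_WIDTH`. As a quick illustration of how the `[quant_min, quant_max]` range used later is derived from the bit width — a minimal sketch, assuming torch 2.3+ where `torch.uint1` .. `torch.uint7` exist:

```python
import torch
from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH

# An unsigned b-bit integer represents 0 .. 2**b - 1, which is exactly
# the asymmetric [quant_min, quant_max] range computed in this file.
for dtype, bits in sorted(_DTYPE_TO_BIT_WIDTH.items(), key=lambda kv: kv[1]):
    print(dtype, "->", (0, 2**bits - 1))  # e.g. torch.uint4 -> (0, 15)
```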
@@ -35,58 +44,75 @@ def insert_awq_observer_(model: torch.nn.Module, n_validation_examples: int, val
         group_size: Quantization granularity. Use -1 for channel wise quantization
     """
     _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
-    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
+    assert (
+        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
     # AQT config
     mapping_type = MappingType.ASYMMETRIC
     quantization_granularity = PerGroup(group_size)
     quant_min = 0
-    quant_max = 255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
+    quant_max = (
+        255 if quant_dtype == torch.uint8 else 2 ** _DTYPE_TO_BIT_WIDTH[quant_dtype] - 1
+    )
     eps = torch.finfo(torch.float32).eps
     preserve_zero = True
     zero_point_dtype = torch.int64
     zero_point_domain = ZeroPointDomain.INT
-
 
     def replace_with_observer(layer):
         # creates observer and replaces linear layers with AWQObservedLinear layers
         observer = AWQObserver(
             layer.weight,
-            layer.bias,
-            quantization_granularity,
+            layer.bias,
+            quantization_granularity,
             mapping_type,
-            quant_dtype,
+            quant_dtype,
             n_validation_examples,
             validation_sequence_len,
             scale_search_space_size,
-            preserve_zero = preserve_zero,
-            zero_point_domain = zero_point_domain,
-            zero_point_dtype = zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+            zero_point_dtype=zero_point_dtype,
             quant_min=quant_min,
-            quant_max = quant_max,
-            eps = eps)
+            quant_max=quant_max,
+            eps=eps,
+        )
         return AWQObservedLinear.from_float(layer, observer)
+
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
+
 def _observed_linear_subclass_inserter(constructor):
     """
     Replaces unquantized AWQObservedLinear instances with quantized linear instances.
 
     Args:
         constructor: the function which applies quantization to the AWQObservedLinear layer
     """
+
     def insert_subclass(observed_linear):
         # creates the new linear layer using constructor
-        linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, observed_linear.bias != None, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
-        linear.weight = torch.nn.Parameter(constructor(observed_linear), requires_grad=False)
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            observed_linear.bias != None,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
+        linear.weight = torch.nn.Parameter(
+            constructor(observed_linear), requires_grad=False
+        )
         linear.bias = observed_linear.bias
         return linear
 
     return insert_subclass
-
 
-def awq_uintx(quant_dtype: torch.dtype = torch.uint4,
-              group_size: int = 64,
-              use_hqq: bool = False,):
+
+def awq_uintx(
+    quant_dtype: torch.dtype = torch.uint4,
+    group_size: int = 64,
+    use_hqq: bool = False,
+):
     """
     Quantizes linear layers when passed into quantize_()
 
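For orientation, this is how the observer-insertion half of the flow is typically driven during calibration — a hedged sketch, where `MyModel` and `calibration_batches` are illustrative placeholders, not part of this diff, and the import path follows torchao's prototype AWQ example:

```python
import torch
from torchao.prototype.awq import insert_awq_observer_

model = MyModel().eval()  # hypothetical model containing torch.nn.Linear layers

# Swap every nn.Linear for an AWQObservedLinear that records activation stats.
insert_awq_observer_(
    model,
    n_validation_examples=10,
    validation_sequence_len=512,
    quant_dtype=torch.uint4,
    scale_search_space_size=20,
    group_size=128,
)

# Run calibration data through the model so each observer sees real activations.
with torch.no_grad():
    for batch in calibration_batches:  # hypothetical iterable of input tensors
        model(batch)
```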
@@ -95,8 +121,10 @@ def awq_uintx(quant_dtype: torch.dtype = torch.uint4,
         group_size: Quantization granularity. Use -1 for channel wise quantization
         weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
     """
-    assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
-
+    assert (
+        quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
+    ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
+
     def weight_quant_func(observed_linear):
         equalization_scale = observed_linear.act_obs.calculate_qparams()
         # AQT config
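The key AWQ step is visible at the top of `weight_quant_func`: the weight is quantized after being multiplied by a per-input-channel `equalization_scale`, and the same scale is later attached as activation metadata so inference divides activations by it. The pair is algebraically a no-op on the linear output; only the rounding error moves onto less activation-sensitive channels. A minimal sketch of that identity with plain tensors (no torchao types):

```python
import torch

x = torch.randn(4, 8)             # activations (batch, in_features)
w = torch.randn(16, 8)            # linear weight (out_features, in_features)
s = torch.rand(8) + 0.5           # per-input-channel equalization scale

y = x @ w.t()
y_scaled = (x / s) @ (w * s).t()  # scale weights up, activations down
assert torch.allclose(y, y_scaled, atol=1e-4)
```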
@@ -114,24 +142,28 @@ def weight_quant_func(observed_linear):
         zero_point_dtype = torch.int64
         zero_point_domain = ZeroPointDomain.INT
         _layout = UintxLayout(quant_dtype)
-
+
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
         quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
         qw = to_affine_quantized_intx(
             observed_linear.weight * equalization_scale,
             mapping_type,
-            block_size,
-            target_dtype, quant_min,
-            quant_max, eps,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
             zero_point_dtype=zero_point_dtype,
             preserve_zero=preserve_zero,
             zero_point_domain=zero_point_domain,
             _layout=_layout,
-            use_hqq=use_hqq
+            use_hqq=use_hqq,
         )
-
-        return to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
-
+
+        return to_weight_tensor_with_linear_activation_scale_metadata(
+            qw, equalization_scale
+        )
+
     return _observed_linear_subclass_inserter(weight_quant_func)
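This formatting pass leaves the end-to-end API unchanged. Continuing the calibration sketch above, the second half of the flow applies `awq_uintx` through `quantize_` with a filter so that only the observed linears (which carry calibration state) are converted — a hedged sketch, assuming the exports shown in torchao's prototype AWQ example:

```python
import torch
from torchao.prototype.awq import AWQObservedLinear, awq_uintx
from torchao.quantization import quantize_

# Restrict quantize_ to the AWQObservedLinear modules inserted earlier,
# rather than the default nn.Linear filter.
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)

quantize_(
    model,  # the calibrated model from the sketch above
    awq_uintx(quant_dtype=torch.uint4, group_size=64),
    is_observed_linear,
)
```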