 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Optional, Tuple
+from typing import Any, Optional

 import torch
 import torch.nn.functional as F
@@ -196,40 +196,15 @@ def convert(
         """
         self._convert_helper(model)
         return model
-
-    @staticmethod
-    def quantize_weights(
-        weight: torch.Tensor,
-        bit_width: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Helper function to quantize weights
-        """
-        (qmin, qmax) = _get_qmin_qmax(bit_width)
-        (s, zp) = get_group_qparams_symmetric(
-            weight, bit_width, group_size
-        )
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )
-        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-            weight,
-            s,
-            zp,
-            qmin,
-            qmax,
-            torch.int8,
-            group_size,
-        )
-        return (q_weight, s, zp)
-

     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
         modules with `Int4WeightOnlyEmbedding`
         """
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )

         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
@@ -255,8 +230,20 @@ def _convert_helper(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_embedding)

-                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                 # Load weights and qparams into quantized embedding
+                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
+                (s, zp) = get_group_qparams_symmetric(
+                    child.weight, self.bit_width, group_size
+                )
+                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+                    child.weight,
+                    s,
+                    zp,
+                    qmin,
+                    qmax,
+                    torch.int8,
+                    group_size,
+                )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scale = s.to(scale_precision)
                 quantized_embedding.zero_point = zp.to(zero_point_precision)
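For context on what the inlined calls compute: `_get_qmin_qmax` presumably yields the signed int4 range [-8, 7], `get_group_qparams_symmetric` derives one scale per group of `group_size` weights with a zero point of 0, and the wrapper rounds each group into that range, storing the result in an int8 container. The sketch below is a minimal, self-contained approximation of this per-channel-group symmetric quantization, not the torchao wrapper itself; `quantize_per_group_reference` is a hypothetical name, and torchao's exact scale formula may differ slightly (e.g. dividing by (qmax - qmin) / 2 rather than qmax).

```python
import torch


def quantize_per_group_reference(
    weight: torch.Tensor, bit_width: int = 4, group_size: int = 32
):
    # Symmetric signed range, e.g. [-8, 7] for 4 bits
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    # View each embedding row as groups of `group_size` values
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    # One scale per group from its max absolute value; zero point stays 0
    scale = (grouped.abs().amax(dim=-1, keepdim=True) / qmax).clamp(min=1e-9)
    zero_point = torch.zeros_like(scale, dtype=torch.int32)
    # Round into the quantized range and store in an int8 container
    q = torch.clamp(torch.round(grouped / scale), qmin, qmax).to(torch.int8)
    return q.reshape(weight.shape), scale.squeeze(-1), zero_point.squeeze(-1)


w = torch.randn(8, 64)
q_w, s, zp = quantize_per_group_reference(w, bit_width=4, group_size=32)
# Round-trip check: dequantize and measure the worst-case error
deq = q_w.reshape(8, -1, 32).float() * s.unsqueeze(-1)
print((w - deq.reshape(w.shape)).abs().max())
```

Because the scheme is symmetric, the zero points are all zeros, which is why only the scales matter when the converted embedding dequantizes its weight.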