5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
from dataclasses import dataclass
8
- from typing import Any , Callable , List , Optional , Union
8
+ from typing import Any , List , Optional , Union
9
9
10
10
import torch
11
11
12
+ from torchao .core .config import AOBaseConfig
12
13
from torchao .quantization .granularity import (
13
14
Granularity ,
14
15
PerAxis ,
22
23
TorchAODType ,
23
24
ZeroPointDomain ,
24
25
)
26
+ from torchao .quantization .transform_module import (
27
+ register_quantize_module_handler ,
28
+ )
25
29
from torchao .quantization .unified import TwoStepQuantizer
26
30
27
31
@@ -241,12 +245,26 @@ def __setattr__(self, name: str, value: Any):
241
245
super ().__setattr__ (name , value )
242
246
243
247
244
- def intx_quantization_aware_training (
245
- activation_config : Optional [FakeQuantizeConfig ] = None ,
246
- weight_config : Optional [FakeQuantizeConfig ] = None ,
247
- ) -> Callable :
248
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
    """
    Config for applying fake quantization (QAT) to a ``torch.nn.Module``,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    Args:
        activation_config: fake quantization settings for input activations,
            or ``None`` to leave activations unquantized.
        weight_config: fake quantization settings for weights, or ``None``
            to leave weights unquantized.
    """

    activation_config: Optional[FakeQuantizeConfig] = None
    weight_config: Optional[FakeQuantizeConfig] = None


# for BC: the old function-style entry point is now an alias of the config class
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
256
+
257
+
258
@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
    module: torch.nn.Module,
    config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
    can break at any time.

    Apply fake quantization to a `torch.nn.Module`.
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    Example usage::

        activation_config = FakeQuantizeConfig(...)
        weight_config = FakeQuantizeConfig(...)
        quantize_(
            model,
            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
        )

    Note: If this transform is applied on a module that is not
    `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
    `torch.nn.Embedding` with an activation config, then we will raise
    ValueError as these are not supported.
    """
    # Deferred imports, kept local to the transform as in the original code
    # (presumably to avoid import cycles at module load time — TODO confirm).
    from .embedding import FakeQuantizedEmbedding
    from .linear import FakeQuantizedLinear

    mod = module
    activation_config = config.activation_config
    weight_config = config.weight_config

    if isinstance(mod, torch.nn.Linear):
        # Swap in a fake-quantized linear; both activation and weight
        # fake quantization are supported here.
        return FakeQuantizedLinear.from_linear(
            mod,
            activation_config,
            weight_config,
        )
    elif isinstance(mod, torch.nn.Embedding):
        # Embeddings have no input activations to fake-quantize.
        if activation_config is not None:
            raise ValueError(
                "Activation fake quantization is not supported for embedding"
            )
        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
    else:
        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
296
311
297
- return _insert_fake_quantize
298
312
299
-
300
- def from_intx_quantization_aware_training () -> Callable :
313
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
    """
    Config for converting a model containing fake quantized modules,
    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
    back to a model with the original, corresponding modules without
    fake quantization.

    Example usage::

        from torchao.quantization import quantize_

        quantize_(
            model_with_fake_quantized_linears,
            FromIntXQuantizationAwareTrainingConfig(),
        )
    """

    # No options: the conversion is fully determined by the module types.
    pass


# for BC: the old function-style entry point is now an alias of the config class
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
325
336
326
- if isinstance (mod , FakeQuantizedLinear ):
327
- return mod .to_linear ()
328
- elif isinstance (mod , FakeQuantizedEmbedding ):
329
- return mod .to_embedding ()
330
- else :
331
- return mod
332
337
333
- return _remove_fake_quantize
338
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
    mod: torch.nn.Module,
    config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    Undo a QAT module swap: if the given module is a fake quantized module,
    return the original corresponding version of the module without fake
    quantization; any other module is returned unchanged.
    """
    # Imports kept local to the transform as in the original code
    # (presumably to avoid import cycles — TODO confirm).
    from .embedding import FakeQuantizedEmbedding
    from .linear import FakeQuantizedLinear

    if isinstance(mod, FakeQuantizedLinear):
        return mod.to_linear()
    if isinstance(mod, FakeQuantizedEmbedding):
        return mod.to_embedding()
    return mod
334
356
335
357
336
358
class ComposableQATQuantizer (TwoStepQuantizer ):
0 commit comments