5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
from dataclasses import dataclass
8
- from typing import Any , Callable , List , Optional , Union
8
+ from typing import Any , List , Optional , Union
9
9
10
10
import torch
11
11
12
+ from torchao .core .config import AOBaseWorkflowConfig
13
+ from torchao .quantization ._transform_module import (
14
+ register_quantize_module_handler ,
15
+ )
12
16
from torchao .quantization .granularity import (
13
17
Granularity ,
14
18
PerAxis ,
@@ -239,12 +243,26 @@ def __setattr__(self, name: str, value: Any):
239
243
super ().__setattr__ (name , value )
240
244
241
245
242
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
    """
    Config for applying fake quantization to a `torch.nn.Module`,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    # NOTE(review): both fields default to None, meaning "no fake
    # quantization" for that tensor role — confirm against the transform.
    """

    activation_config: Optional[FakeQuantizeConfig] = None
    weight_config: Optional[FakeQuantizeConfig] = None


# for BC: the old function-style API name now points at the config class.
intx_quantization_aware_training = IntXQuantizationAwareTrainingConfig
254
+
255
+
256
@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
    module: torch.nn.Module,
    config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    THIS IS NOT A PUBLIC API - any usage of this outside of torchao
    can break at any time.

    Apply fake quantization to a `torch.nn.Module`,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    Swaps `torch.nn.Linear` for `FakeQuantizedLinear` and
    `torch.nn.Embedding` for `FakeQuantizedEmbedding`. Passing a
    `torch.nn.Embedding` with an activation config, or any other module
    type, raises ValueError as these are not supported.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .embedding import FakeQuantizedEmbedding
    from .linear import FakeQuantizedLinear

    activation_config = config.activation_config
    weight_config = config.weight_config

    if isinstance(module, torch.nn.Linear):
        return FakeQuantizedLinear.from_linear(
            module,
            activation_config,
            weight_config,
        )
    if isinstance(module, torch.nn.Embedding):
        # Embeddings have no activations to fake-quantize.
        if activation_config is not None:
            raise ValueError(
                "Activation fake quantization is not supported for embedding"
            )
        return FakeQuantizedEmbedding.from_embedding(module, weight_config)
    raise ValueError("Module of type '%s' does not have QAT support" % type(module))
294
309
295
- return _insert_fake_quantize
296
310
297
-
298
class FromIntXQuantizationAwareTrainingConfig(AOBaseWorkflowConfig):
    """
    Object that knows how to convert a model with fake quantized modules,
    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
    back to model with the original, corresponding modules without
    fake quantization.
    """

    # Stateless marker config: the registered transform does all the work.
    pass


# for BC: the old function-style API name now points at the config class.
from_intx_quantization_aware_training = FromIntXQuantizationAwareTrainingConfig
323
334
324
@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
    mod: torch.nn.Module,
    config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    If the given module is a fake quantized module, return the original
    corresponding version of the module without fake quantization.
    Any other module is returned unchanged.
    """
    # Imported lazily to avoid a circular import at module load time.
    from .embedding import FakeQuantizedEmbedding
    from .linear import FakeQuantizedLinear

    if isinstance(mod, FakeQuantizedLinear):
        return mod.to_linear()
    if isinstance(mod, FakeQuantizedEmbedding):
        return mod.to_embedding()
    return mod
332
354
333
355
334
356
class ComposableQATQuantizer (TwoStepQuantizer ):
0 commit comments