# LICENSE file in the root directory of this source tree.

import logging
+from enum import Enum, auto
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    get_tensor_impl_constructor,
    register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
    register_aqt_quantized_linear_dispatch,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
+    MappingType,
+    choose_qparams_affine,
+    quantize_affine,
+)
+
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
)

logger = logging.getLogger(__name__)
handler.setFormatter(formatter)
logger.addHandler(handler)
+class Target(Enum):
+    """Enum that indicates the backend target"""
+
+    NATIVE = auto()
+    ATEN = auto()
+
+def target_from_str(target: str) -> Target:
+    if target.lower() == "native":
+        return Target.NATIVE
+    elif target.lower() == "aten":
+        return Target.ATEN
+    else:
+        raise ValueError(f"Invalid target: {target}")

class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
    bit_width: Optional[int]
    group_size: Optional[int]
    has_weight_zeros: Optional[bool]
+    # The target platform for the layout, 'native' or 'aten'
+    target: Optional[Target]

    def __init__(
        self,
        bit_width: Optional[int] = None,
        group_size: Optional[int] = None,
        has_weight_zeros: Optional[bool] = None,
+        target: Optional[str] = "native",
    ):
        if bit_width is not None:
            assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
@@ -51,6 +77,7 @@ def __init__(
        self.bit_width = bit_width
        self.group_size = group_size
        self.has_weight_zeros = has_weight_zeros
+        self.target = target_from_str(target)

        if not self.has_params_set():
            assert (
@@ -60,13 +87,14 @@ def __init__(
            ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"

    def extra_repr(self):
-        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
+        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"

    def has_params_set(self) -> bool:
        return (
            (self.bit_width is not None)
            and (self.group_size is not None)
            and (self.has_weight_zeros is not None)
+            and (self.target is not None)
        )

@@ -125,9 +153,11 @@ def from_plain(
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        layout: Layout,
+        bias: Optional[torch.Tensor] = None,
    ):
        assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
        assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
+        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"

        # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
        # when AOTI supports int
@@ -136,6 +166,13 @@ def from_plain(
        n_tensor = torch.empty(0, n, dtype=torch.int8)
        k_tensor = torch.empty(0, k, dtype=torch.int8)

+        if layout.target == Target.ATEN:
+            assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version >= 2.6.0"
+            # Shift int4 values from [-8, 7] into the unsigned range [0, 15],
+            # then pack two 4-bit values per byte for the aten packing op.
+            int_data = int_data.add(8)
+            int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
+
        if layout.has_weight_zeros:
            args = [
                int_data.to(torch.int8),
@@ -211,16 +248,13 @@ def __tensor_unflatten__(
def _linear_check(input_tensor, weight_tensor, bias):
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
    )


def _linear_impl(input_tensor, weight_tensor, bias):
-    assert (
-        bias is None
-    ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"

-    def _impl_2d(input_tensor, weight_tensor):
+    def _impl_2d_native(input_tensor, weight_tensor):
        assert input_tensor.dim() == 2
        assert weight_tensor.dim() == 2

@@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor):
            torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
        )(*args)

+    def _impl_2d_aten(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+        return torch.ops.aten._dyn_quant_matmul_4bit(
+            input_tensor, packed_weight, group_size, k_, n)
+
+    target = weight_tensor.tensor_impl.get_layout().target
+
+    if target == Target.ATEN:
+        assert (
+            TORCH_VERSION_AT_LEAST_2_6
+        ), "Target.ATEN requires torch >= 2.6.0"
+        _impl_2d = _impl_2d_aten
+    elif target == Target.NATIVE:
+        _impl_2d = _impl_2d_native
+        assert (
+            bias is None
+        ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native'"
+
    if input_tensor.dim() == 2:
        return _impl_2d(input_tensor, weight_tensor)

@@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor):
    res = res.reshape(*lead_shape, m, n)
    return res

-
register_aqt_quantized_linear_dispatch(
    _linear_check,
    _linear_impl,
)
+
+
+class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
+    """
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor is a quantized tensor subclass that inherits from AffineQuantizedTensor.
+    """
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
+        use_hqq: bool = False,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightAtenTensor does not support HQQ optimization"
+        assert isinstance(
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightAtenTensor only supports PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightAtenTensor requires target 'aten'."
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_affine(
+            input_float,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+            preserve_zero,
+            zero_point_domain,
+        )
+        # choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None if its domain is None.
+        # TODO should probably consolidate ZeroPointDomain.NONE and None
+        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            zero_point = None
+        data = quantize_affine(
+            input_float,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        # Note: output will be uint8 tensor for sub byte tensors for now
+
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
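
For reference, a minimal usage sketch (not part of the diff above) of the new constructor with the 'aten' target. The weight shape, group size, symmetric 4-bit settings, and zero_point_domain choice here are illustrative assumptions rather than the library's prescribed recipe, and running it requires torch >= 2.6 for the aten 4-bit ops.

if __name__ == "__main__":
    # Hypothetical demo inputs: an [n, k] float weight and its bias.
    weight = torch.randn(256, 512, dtype=torch.float32)
    bias = torch.randn(256, dtype=torch.float32)
    group_size = 32

    layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
        bit_width=4,  # the aten packing op handles 4-bit weights
        group_size=group_size,
        has_weight_zeros=False,
        target="aten",
    )
    qweight = to_packedlinearint8dynamicactivationintxweight_quantized_intx(
        weight,
        MappingType.SYMMETRIC,
        (1, group_size),  # per-group blocks along the k dimension
        torch.int8,
        quant_min=-8,
        quant_max=7,
        zero_point_domain=ZeroPointDomain.NONE,  # assumed: no weight zero points
        _layout=layout,
        bias=bias,
    )
    print(qweight.shape, qweight.dtype)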