
Commit 7815262

[Feat]: Add support for kleidiai quantization schemes (#1447)
1 parent 463a872 commit 7815262

5 files changed (+328 −38 lines)

torchao/experimental/docs/readme.md

+31
@@ -98,6 +98,37 @@ quantize_(
 )
 ```
 
+KleidiAI int4 kernels can be used on Arm platforms with PyTorch 2.6.0 or later by adjusting the quantization parameters as follows:
+
+```python
+from torchao.dtypes import PlainLayout
+from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
+    PackedLinearInt8DynamicActivationIntxWeightLayout,
+)
+from torchao.experimental.quant_api import (
+    int8_dynamic_activation_intx_weight,
+)
+from torchao.quantization.granularity import (
+    PerGroup,
+    PerRow,
+)
+from torchao.quantization.quant_api import quantize_
+from torchao.quantization.quant_primitives import MappingType
+
+my_model = Model()
+
+quantize_(
+    my_model,
+    int8_dynamic_activation_intx_weight(
+        weight_dtype=torch.int4,
+        granularity=PerGroup(32),  # PerRow() is also supported
+        has_weight_zeros=True,  # should be True
+        weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,  # MappingType.SYMMETRIC can also be used but increases error
+        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
+    ),
+)
+```
+
 If you get stuck, consult
 `torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`
 for a working example.
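As a quick sanity check before opting into `target="aten"`, you can verify that the KleidiAI-backed aten ops are present in your PyTorch build. A minimal sketch, assuming these ops are only registered in builds that ship them (PyTorch >= 2.6, typically Arm):

```python
import torch

# Check that the aten 4-bit dynamic-quant ops used by the "aten" target exist
# in this PyTorch build before selecting target="aten".
has_kleidiai_ops = hasattr(torch.ops.aten, "_dyn_quant_pack_4bit_weight") and hasattr(
    torch.ops.aten, "_dyn_quant_matmul_4bit"
)
print(f"aten 4-bit dynamic quant ops available: {has_kleidiai_ops}")
```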

torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

+140 −7
@@ -5,12 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from enum import Enum, auto
 from typing import Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    get_tensor_impl_constructor,
     register_layout,
 )
 from torchao.dtypes.affine_quantized_tensor_ops import (
@@ -19,6 +22,13 @@
 from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
+    MappingType,
+    choose_qparams_affine,
+    quantize_affine,
+)
+
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
 )
 
 logger = logging.getLogger(__name__)
@@ -31,17 +41,33 @@
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 
+class Target(Enum):
+    """Enum that indicates the backend target"""
+
+    NATIVE = auto()
+    ATEN = auto()
+
+def target_from_str(target: str) -> Target:
+    if target.lower() == "native":
+        return Target.NATIVE
+    elif target.lower() == "aten":
+        return Target.ATEN
+    else:
+        raise ValueError(f"Invalid target: {target}")
 
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
     bit_width: Optional[int]
     group_size: Optional[int]
     has_weight_zeros: Optional[bool]
+    # The target platform for the layout, 'native' or 'aten'
+    target: Optional[Target]
 
     def __init__(
         self,
         bit_width: Optional[int] = None,
         group_size: Optional[int] = None,
         has_weight_zeros: Optional[bool] = None,
+        target: Optional[str] = "native",
     ):
         if bit_width is not None:
             assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
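For reference, a small sketch of how the `Target` enum and `target_from_str` helper added above behave; the import path is the module changed in this hunk, and anything other than "native"/"aten" is rejected:

```python
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    Target,
    target_from_str,
)

# Lookup is case-insensitive; only "native" and "aten" are accepted.
assert target_from_str("aten") is Target.ATEN
assert target_from_str("Native") is Target.NATIVE
try:
    target_from_str("cuda")
except ValueError as e:
    print(e)  # Invalid target: cuda
```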
@@ -51,6 +77,7 @@ def __init__(
         self.bit_width = bit_width
         self.group_size = group_size
         self.has_weight_zeros = has_weight_zeros
+        self.target = target_from_str(target)
 
         if not self.has_params_set():
             assert (
@@ -60,13 +87,14 @@ def __init__(
             ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False"
 
     def extra_repr(self):
-        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}"
+        return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, target={self.target}"
 
     def has_params_set(self) -> bool:
         return (
             (self.bit_width is not None)
             and (self.group_size is not None)
             and (self.has_weight_zeros is not None)
+            and (self.target is not None)
         )
 
 
@@ -125,9 +153,11 @@ def from_plain(
         scale: torch.Tensor,
         zero_point: Optional[torch.Tensor],
         layout: Layout,
+        bias: Optional[torch.Tensor] = None,
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
         assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
+        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
 
         # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
         # when AOTI supports int
@@ -136,6 +166,13 @@ def from_plain(
         n_tensor = torch.empty(0, n, dtype=torch.int8)
         k_tensor = torch.empty(0, k, dtype=torch.int8)
 
+        if layout.target == Target.ATEN:
+            assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version >= 2.6.0"
+            int_data = int_data.add(8)
+            int_data = (int_data[:, 1::2] << 4 | int_data[:, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
+
         if layout.has_weight_zeros:
             args = [
                 int_data.to(torch.int8),
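To make the aten packing step above more concrete, here is a minimal standalone sketch of the nibble packing it performs before handing the weights to `torch.ops.aten._dyn_quant_pack_4bit_weight`; the example tensor and dtype below are illustrative, not taken from the commit:

```python
import torch

# Signed int4 weight values lie in [-8, 7]; adding 8 maps them to [0, 15]
# so each value fits in one unsigned nibble.
int_data = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int32)  # toy 1x4 weight row
shifted = int_data.add(8)                                     # [[0, 15, 8, 7]]

# Pack pairs of adjacent columns into one byte: the odd-indexed column goes
# into the high nibble, the even-indexed column into the low nibble.
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)
print(packed)  # tensor([[240, 120]], dtype=torch.uint8)
```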
@@ -211,16 +248,13 @@ def __tensor_unflatten__(
 def _linear_check(input_tensor, weight_tensor, bias):
     layout = weight_tensor.tensor_impl.get_layout()
     return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
     )
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-    assert (
-        bias is None
-    ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl"
 
-    def _impl_2d(input_tensor, weight_tensor):
+    def _impl_2d_native(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
 
@@ -255,6 +289,31 @@ def _impl_2d(input_tensor, weight_tensor):
             torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight"
         )(*args)
 
+    def _impl_2d_aten(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.tensor_impl.get_layout().group_size
+        packed_weight = weight_tensor.tensor_impl.packed_weight
+        return torch.ops.aten._dyn_quant_matmul_4bit(
+            input_tensor, packed_weight, group_size, k_, n)
+
+    target = weight_tensor.tensor_impl.get_layout().target
+
+    if target == Target.ATEN:
+        assert (
+            TORCH_VERSION_AT_LEAST_2_6
+        ), "Target.ATEN requires torch >= 2.6.0"
+        _impl_2d = _impl_2d_aten
+    elif target == Target.NATIVE:
+        _impl_2d = _impl_2d_native
+        assert (
+            bias is None
+        ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl with target 'native'"
+
     if input_tensor.dim() == 2:
         return _impl_2d(input_tensor, weight_tensor)
 
@@ -268,8 +327,82 @@ def _impl_2d(input_tensor, weight_tensor):
     res = res.reshape(*lead_shape, m, n)
     return res
 
-
 register_aqt_quantized_linear_dispatch(
     _linear_check,
     _linear_impl,
 )
+
+
+class PackedLinearInt8DynamicActivationIntxWeightAtenTensor(AffineQuantizedTensor):
+    """
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor is a quantized tensor subclass that inherits from AffineQuantizedTensor.
+    """
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        mapping_type: MappingType,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        eps: Optional[float] = None,
+        scale_dtype: Optional[torch.dtype] = None,
+        zero_point_dtype: Optional[torch.dtype] = None,
+        preserve_zero: bool = True,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
+        use_hqq: bool = False,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        assert not use_hqq, "PackedLinearInt8DynamicActivationIntxWeightTensor does not support HQQ optimization"
+        assert isinstance(
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor only supports PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert _layout.target == Target.ATEN, "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+
+        scale, zero_point = choose_qparams_affine(
+            input_float,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            scale_dtype,
+            zero_point_dtype,
+            preserve_zero,
+            zero_point_domain,
+        )
+        # choose_qparams_affine is a custom op that does not support returning optional Tensors,
+        # so we set zero_point to None here if its domain is None.
+        # TODO: should probably consolidate ZeroPointDomain.NONE and None
+        if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
+            zero_point = None
+        data = quantize_affine(
+            input_float,
+            block_size,
+            scale,
+            zero_point,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        # Note: output will be a uint8 tensor for sub-byte dtypes for now
+
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout, bias)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
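The layout object added in this file is what carries the backend choice that the dispatch code keys on. A small sketch of constructing the two variants and inspecting the new `target` field via `extra_repr` (values here are illustrative); per the `_linear_check` change above, "aten" also allows a fused bias while "native" keeps the existing torchao kernels, where bias is unsupported:

```python
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

# "native" routes to the existing torchao kernels; "aten" routes to the
# KleidiAI-backed torch.ops.aten._dyn_quant_* ops (PyTorch >= 2.6).
native_layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4, group_size=32, has_weight_zeros=True, target="native"
)
aten_layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
    bit_width=4, group_size=32, has_weight_zeros=True, target="aten"
)
print(native_layout.extra_repr())
# e.g. group_size=32, bit_width=4, has_weight_zeros=True, target=Target.NATIVE
print(aten_layout.extra_repr())
# e.g. group_size=32, bit_width=4, has_weight_zeros=True, target=Target.ATEN
```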
