Commit 23db9bf

Move MarlinQQQTensor out of AQT (#1385)
1 parent 8a805d0 commit 23db9bf

6 files changed: +89 −60 lines changed

Diff for: torchao/dtypes/README.md (+19)

@@ -0,0 +1,19 @@
+# README
+
+## File Structure of the `dtypes` Folder
+
+The `dtypes` folder contains several important files and subfolders that are organized as follows:
+
+- **affine_quantized_tensor.py**: This is the main file; it contains the base tensor subclass `AffineQuantizedTensor` and the code for layout and TensorImpl registration, and the tensor subclasses in the `uintx` and `floatx` subfolders inherit from it.
+
+- **affine_quantized_tensor_ops.py**: This file defines all the overridden aten ops and the different dispatch kernels related to affine quantized tensors.
+
+- **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder.
+
+- **nf4tensor.py**: This file contains the NF4 tensor implementation and its layouts.
+
+### Subfolders
+
+- **uintx**: A subfolder that contains layouts and tensor subclasses inheriting from `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors.
+
+- **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that inherit from `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors.
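For context, the "layout and TensorImpl registration" the README mentions follows a decorator pattern. Below is a minimal sketch using only names visible in this commit (`Layout`, `AQTTensorImpl`, `register_layout`); `MyLayout` and `MyTensorImpl` are hypothetical placeholders, and the exact decorator usage is an assumption modeled on how existing torchao layouts are wired up, not something shown in this diff.

```python
# Hedged sketch of the registration pattern described in the README.
# MyLayout / MyTensorImpl are hypothetical, illustration-only names.
from dataclasses import dataclass

from torchao.dtypes.affine_quantized_tensor import register_layout
from torchao.dtypes.utils import AQTTensorImpl, Layout


@dataclass(frozen=True)
class MyLayout(Layout):
    # Layouts are lightweight markers; real ones may carry packing params.
    pass


@register_layout(MyLayout)
class MyTensorImpl(AQTTensorImpl):
    # A real TensorImpl implements packing/unpacking for this layout,
    # e.g. a get_plain() that returns the dequantization inputs.
    pass
```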

Diff for: torchao/dtypes/__init__.py (+2 −2)

@@ -1,14 +1,12 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    MarlinQQQTensor,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
-    to_marlinqqq_quantized_intx,
 )
 from .floatx import (
     Float8Layout,
@@ -18,10 +16,12 @@
     BlockSparseLayout,
     Int4CPULayout,
     MarlinQQQLayout,
+    MarlinQQQTensor,
     MarlinSparseLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_marlinqqq_quantized_intx,
 )
 from .utils import (
     Layout,
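The net effect of these two hunks: `MarlinQQQTensor` and `to_marlinqqq_quantized_intx` now come from the `uintx` subpackage, but the public `torchao.dtypes` import path still works. A quick sanity check, assuming a build that includes this commit:

```python
# Old public path (unchanged for users) and the new internal home
# should resolve to the same object after this commit.
from torchao.dtypes import MarlinQQQTensor, to_marlinqqq_quantized_intx
from torchao.dtypes.uintx.marlin_qqq_tensor import (
    MarlinQQQTensor as MarlinQQQTensorMoved,
)

assert MarlinQQQTensor is MarlinQQQTensorMoved
```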

Diff for: torchao/dtypes/affine_quantized_tensor.py (−56)

@@ -16,10 +16,8 @@
     choose_qparams_affine,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
-    choose_qparams_and_quantize_affine_qqq,
     dequantize_affine,
     dequantize_affine_floatx,
-    dequantize_affine_qqq,
     quantize_affine,
     quantize_affine_floatx,
 )
@@ -33,14 +31,12 @@
 
 __all__ = [
     "AffineQuantizedTensor",
-    "MarlinQQQTensor",
     "register_layout",
     "to_affine_quantized_intx",
     "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
-    "to_marlinqqq_quantized_intx",
 ]
 
 
@@ -459,57 +455,6 @@ def _apply_fn_to_data(self, fn):
     # 2 - we're given non-floats - quantizing long to int8 is crazy
 
 
-class MarlinQQQTensor(AffineQuantizedTensor):
-    """
-    MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.
-
-    To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization,
-    please check out https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
-    and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq
-    """
-
-    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-        if output_dtype is None:
-            output_dtype = self.dtype
-
-        int_data, s_group, s_channel = self.tensor_impl.get_plain()
-        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
-        group_size = max(self.block_size)
-        return dequantize_affine_qqq(
-            int_data, s_group, s_channel, nbits, group_size, output_dtype
-        )
-
-    @classmethod
-    def from_hp_to_intx(
-        cls,
-        input_float: torch.Tensor,
-        block_size: Tuple[int, ...],
-        quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
-        _layout: Optional[Layout] = None,
-    ):
-        original_shape = input_float.shape
-        input_float = _layout.pre_process(input_float)
-        nbits = int(math.log2(quant_max - quant_min + 1))
-        group_size = max(block_size)
-        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
-            input_float, nbits, group_size
-        )
-        data = _layout.post_process(data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-            dtype=input_float.dtype,
-        )
-
-
 ######################################################
 # Layout and TensorImpl Subclass Registration #
 ######################################################
@@ -522,7 +467,6 @@
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
-to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx
 
 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
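A detail worth noting in the moved code is how the bit width is recovered from the quantization range: `nbits = int(math.log2(quant_max - quant_min + 1))`. A worked instance, with a signed int4 range chosen purely for illustration (these `quant_min`/`quant_max` values are not taken from this diff):

```python
import math

# A signed 4-bit range [-8, 7] spans 7 - (-8) + 1 = 16 levels,
# and log2(16) = 4, so the computation recovers nbits = 4.
quant_min, quant_max = -8, 7
nbits = int(math.log2(quant_max - quant_min + 1))
assert nbits == 4
```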

Diff for: torchao/dtypes/affine_quantized_tensor_ops.py (+1 −1)

@@ -20,7 +20,7 @@
     _linear_int8_act_int8_weight_block_sparse_check,
     _linear_int8_act_int8_weight_block_sparse_impl,
 )
-from torchao.dtypes.uintx.marlin_qqq_layout import (
+from torchao.dtypes.uintx.marlin_qqq_tensor import (
     _linear_int8_act_int4_weight_marlin_qqq_check,
     _linear_int8_act_int4_weight_marlin_qqq_impl,
 )
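The renamed import keeps the usual torchao pairing: a `*_check` predicate that decides whether a given (activation, weight, bias) combination can use the kernel, and a `*_impl` that runs it. A hedged sketch of that convention follows; the `dispatch_table` and `quantized_linear` names are illustrative only, not torchao's actual registry in this file:

```python
# Illustrative dispatch loop; the real registry in
# affine_quantized_tensor_ops.py is populated with (check, impl)
# pairs such as the marlin_qqq pair imported above.
from torchao.dtypes.uintx.marlin_qqq_tensor import (
    _linear_int8_act_int4_weight_marlin_qqq_check,
    _linear_int8_act_int4_weight_marlin_qqq_impl,
)

dispatch_table = [
    (
        _linear_int8_act_int4_weight_marlin_qqq_check,
        _linear_int8_act_int4_weight_marlin_qqq_impl,
    ),
]


def quantized_linear(input_tensor, weight_tensor, bias):
    # Try each kernel's predicate in order and run the first match.
    for check, impl in dispatch_table:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no quantized linear kernel matched")
```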

Diff for: torchao/dtypes/uintx/__init__.py (+5 −1)

@@ -1,8 +1,10 @@
 from .block_sparse_layout import (
     BlockSparseLayout,
 )
-from .marlin_qqq_layout import (
+from .marlin_qqq_tensor import (
     MarlinQQQLayout,
+    MarlinQQQTensor,
+    to_marlinqqq_quantized_intx,
 )
 from .marlin_sparse_layout import (
     MarlinSparseLayout,
@@ -26,4 +28,6 @@
     "TensorCoreTiledLayout",
     "Int4CPULayout",
     "MarlinQQQLayout",
+    "MarlinQQQTensor",
+    "to_marlinqqq_quantized_intx",
 ]

Diff for: torchao/dtypes/uintx/marlin_qqq_layout.py renamed to torchao/dtypes/uintx/marlin_qqq_tensor.py (+62)

@@ -1,5 +1,7 @@
 import logging
+import math
 from dataclasses import dataclass
+from typing import Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import (
@@ -8,18 +10,75 @@
 
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
+    get_tensor_impl_constructor,
     register_layout,
 )
 from torchao.dtypes.uintx.plain_layout import (
     _aqt_is_int8_reduced_range,
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout
+from torchao.quantization.quant_primitives import (
+    ZeroPointDomain,
+    choose_qparams_and_quantize_affine_qqq,
+    dequantize_affine_qqq,
+)
 
 logger = logging.getLogger(__name__)
 
 aten = torch.ops.aten
 
 
+class MarlinQQQTensor(AffineQuantizedTensor):
+    """
+    MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.
+
+    To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization,
+    please check out https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
+    and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+
+        int_data, s_group, s_channel = self.tensor_impl.get_plain()
+        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
+        group_size = max(self.block_size)
+        return dequantize_affine_qqq(
+            int_data, s_group, s_channel, nbits, group_size, output_dtype
+        )
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        block_size: Tuple[int, ...],
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Optional[Layout] = None,
+    ):
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+        nbits = int(math.log2(quant_max - quant_min + 1))
+        group_size = max(block_size)
+        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
+            input_float, nbits, group_size
+        )
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+
 @dataclass(frozen=True)
 class MarlinQQQLayout(Layout):
     pass
@@ -279,3 +338,6 @@ def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias):
     if bias is not None:
         out += bias.to(out.dtype)
     return out
+
+
+to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx
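With the class and its constructor alias in their new home, the moved code can be exercised end to end. A hedged usage sketch based on the `from_hp_to_intx` signature above; the shapes, `block_size`, and `quant_min`/`quant_max` values are illustrative assumptions, and Marlin QQQ kernels generally expect a CUDA device:

```python
import torch

from torchao.dtypes import MarlinQQQLayout, to_marlinqqq_quantized_intx

# Illustrative per-group int4 quantization; block_size, quant_min, and
# quant_max are assumptions for this sketch, not values from the diff.
weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
qweight = to_marlinqqq_quantized_intx(
    weight,
    block_size=(1, 128),  # group_size = max(block_size) = 128
    quant_min=-8,         # nbits = log2(7 - (-8) + 1) = 4
    quant_max=7,
    _layout=MarlinQQQLayout(),
)

# Round-trip back to high precision via the overridden dequantize().
w_dq = qweight.dequantize()
print(w_dq.dtype, w_dq.shape)
```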
