Commit ef10f34

Add gguf q4_k quantization (#2001)
* Add gguf q4_k_s quantization

  Summary: Didn't implement the algorithm to choose_qparams from gguf, since it's
  complicated, e.g.
  https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L744
  and
  https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L827C14-L827C28,
  but implemented a simple choose_qparams that can fit the gguf format:

      Q4_K: w = q * block_scale(6-bit) + block_min(6-bit)

  Test Plan: python test/prototype/test_gguf_quant.py

* fix
* test with phi4
* pre-commit run
* update
* run precommit
* format
1 parent 5802d2d commit ef10f34
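
The formula in the commit message describes the Q4_K layout: each 256-element superblock is split into blocks that each carry a 6-bit scale and a 6-bit min, plus two per-superblock float scales used to dequantize those 6-bit values. A minimal sketch of the dequantization this implies (illustrative only; the function and argument names here are assumptions, not torchao's implementation):

import torch

def dequantize_q4_k_sketch(q, block_scale_q, block_min_q, scale_scale, min_scale):
    # q: (num_blocks, 32) integer codes in [0, 15]
    # block_scale_q, block_min_q: (num_blocks,) 6-bit quantized per-block qparams
    # scale_scale, min_scale: per-superblock float scales for those 6-bit values
    block_scale = block_scale_q.float() * scale_scale  # recover per-block scale
    block_min = block_min_q.float() * min_scale        # recover per-block min
    # w = q * block_scale(6-bit) + block_min(6-bit), per the commit message
    return q.float() * block_scale.unsqueeze(-1) + block_min.unsqueeze(-1)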

File tree

7 files changed: +615 −1 lines changed

test/prototype/test_gguf_quant.py

+59 (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchao.prototype.quantization.gguf import (
    GGUFQuantizedTensor,
    GGUFWeightOnlyConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import choose_qparams_gguf
from torchao.quantization.utils import compute_error


class TestGGUFQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.input = torch.randn(2, 256, dtype=torch.float32)
        self.n_blocks_per_superblock = 8
        self.block_size = (1, 32)
        self.dtype = torch.uint4

    def test_choose_qparams_gguf(self):
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        ) = choose_qparams_gguf(self.input, self.block_size, self.dtype)

        self.assertEqual(super_block_scale_scale.shape, (2, 8))
        self.assertEqual(super_block_min_scale.shape, (2, 8))
        self.assertEqual(quantized_block_scale.shape, (2, 32))

    def test_gguf_quantized_tensor_from_float(self):
        gqt = GGUFQuantizedTensor.from_float(
            self.input,
            self.n_blocks_per_superblock,
            self.dtype,
        )

        dequant = gqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)

    def test_quantize_api(self):
        m = torch.nn.Sequential(torch.nn.Linear(256, 64))
        quantize_(m, GGUFWeightOnlyConfig())
        assert type(m[0].weight) == GGUFQuantizedTensor


if __name__ == "__main__":
    unittest.main()
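
The 30 dB threshold above is a signal-to-quantization-noise check. As a point of reference, compute_error conventionally reports SQNR in decibels; a rough sketch of that metric (an assumption about its definition, not a copy of torchao's helper):

import torch

def sqnr_db(signal: torch.Tensor, reconstruction: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: log-ratio of signal magnitude to reconstruction-error magnitude.
    return 20 * torch.log10(
        torch.linalg.norm(signal) / torch.linalg.norm(signal - reconstruction)
    )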

torchao/core/config.py

+5 −1

@@ -171,7 +171,11 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     return json.loads(json.dumps(config, cls=ConfigJSONEncoder))


-ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"}
+ALLOWED_AO_MODULES = {
+    "torchao.quantization",
+    "torchao.sparsity.sparse_api",
+    "torchao.prototype.quantization",
+}


 def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
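
Allow-listing torchao.prototype.quantization is what lets a serialized GGUFWeightOnlyConfig be reconstructed by config_from_dict. A sketch of the round trip this enables (assuming the JSON encoder shown above handles the config's fields):

from torchao.core.config import config_from_dict, config_to_dict
from torchao.prototype.quantization import GGUFWeightOnlyConfig

config = GGUFWeightOnlyConfig(n_blocks_per_superblock=8)
data = config_to_dict(config)      # serialize to a plain dict
restored = config_from_dict(data)  # allowed now that the module is allow-listed
assert isinstance(restored, GGUFWeightOnlyConfig)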

torchao/prototype/quantization/__init__.py

+5 (new file)
from .gguf import GGUFWeightOnlyConfig

__all__ = [
    "GGUFWeightOnlyConfig",
]

torchao/prototype/quantization/gguf/__init__.py

+9 (new file)
from .api import GGUFWeightOnlyConfig
from .gguf_quantized_tensor import (
    GGUFQuantizedTensor,
)

__all__ = [
    "GGUFQuantizedTensor",
    "GGUFWeightOnlyConfig",
]

torchao/prototype/quantization/gguf/api.py

+52 (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import register_quantize_module_handler

from .gguf_quantized_tensor import GGUFQuantizedTensor

__all__ = [
    "GGUFWeightOnlyConfig",
]


@dataclass
class GGUFWeightOnlyConfig(AOBaseConfig):
    dtype: torch.dtype = torch.uint4
    n_blocks_per_superblock: int = 8


@register_quantize_module_handler(GGUFWeightOnlyConfig)
def _gguf_weight_only_transform(
    module: torch.nn.Module,
    config: GGUFWeightOnlyConfig,
):
    """
    Applies gguf weight-only quantization to linear layers.

    Args:
        dtype: torch.uint1 through torch.uint8 and torch.int32 are supported.
        n_blocks_per_superblock: the number of blocks in a 256-element gguf
            superblock, e.g. 8 means the superblock is split into 8 blocks
            of 32 elements each.
    Returns:
        The module with its weight replaced by a GGUFQuantizedTensor; modules
        with incompatible weight shapes are returned unchanged.
    """
    weight = module.weight
    # Only quantize 2D weights whose last dimension fills whole superblocks.
    if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
        return module

    quantized_weight = GGUFQuantizedTensor.from_float(
        weight,
        n_blocks_per_superblock=config.n_blocks_per_superblock,
        target_dtype=config.dtype,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
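
End to end, the config plugs into the standard quantize_ entry point, mirroring test_quantize_api above:

import torch

from torchao.prototype.quantization import GGUFWeightOnlyConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 64))
quantize_(model, GGUFWeightOnlyConfig())
# Weights whose last dimension is not a multiple of 256 are left unquantized,
# per the shape guard in _gguf_weight_only_transform.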
