
Commit 23213ca

AQLM support for LoRA (huggingface#1476)
* aqlm
* Style and copied tests
* aqlm import guard
* docs
* correct model in tests
* Update docs/source/developer_guides/quantization.md (Co-authored-by: Benjamin Bossan <[email protected]>)
* Update docs/source/developer_guides/quantization.md (Co-authored-by: Benjamin Bossan <[email protected]>)
* moved aqlm install and added >=
* Removed `quant_linear_module`
* AqlmLoraLinear
* docs update
* transformers version check

---------

Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 2efc36c commit 23213ca

9 files changed: +235 −5 lines changed


docker/peft-gpu/Dockerfile
Lines changed: 4 additions & 0 deletions

@@ -62,6 +62,10 @@ RUN source activate peft && \
     git+https://github.com/huggingface/accelerate \
     peft[test]@git+https://github.com/huggingface/peft

+# Add aqlm for quantization testing
+RUN source activate peft && \
+    pip install aqlm[gpu]>=1.0.2
+
 RUN source activate peft && \
     pip freeze | grep transformers

docs/source/developer_guides/quantization.md
Lines changed: 22 additions & 0 deletions

@@ -21,6 +21,7 @@ Quantization represents data with fewer bits, making it a useful technique for r
 * optimizing which model weights are quantized with the [AWQ](https://hf.co/papers/2306.00978) algorithm
 * independently quantizing each row of a weight matrix with the [GPTQ](https://hf.co/papers/2210.17323) algorithm
 * quantizing to 8-bit and 4-bit precision with the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library
+* quantizing to as low as 2-bit precision with the [AQLM](https://arxiv.org/abs/2401.06118) algorithm

 However, after a model is quantized it isn't typically further trained for downstream tasks because training can be unstable due to the lower precision of the weights and activations. But since PEFT methods only add *extra* trainable parameters, this allows you to train a quantized model with a PEFT adapter on top! Combining quantization with PEFT can be a good strategy for training even the largest models on a single GPU. For example, [QLoRA](https://hf.co/papers/2305.14314) is a method that quantizes a model to 4-bits and then trains it with LoRA. This method allows you to finetune a 65B parameter model on a single 48GB GPU!

@@ -137,6 +138,27 @@ QLoRA adds trainable weights to all the linear layers in the transformer archite
 config = LoraConfig(target_modules="all-linear", ...)
 ```

+## AQLM quantization
+
+Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for large language models. It quantizes multiple weights together, taking advantage of the interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes, which allows it to compress models down to as low as 2 bits with only a small loss in accuracy.
+
+Since the AQLM quantization process is computationally expensive, using prequantized models is recommended. A partial list of available models can be found in the official aqlm [repository](https://github.com/Vahe1994/AQLM).
+
+The models support LoRA adapter tuning. To tune a quantized model, you'll need to install the `aqlm` inference library: `pip install aqlm>=1.0.2`. Finetuned LoRA adapters must be saved separately, as merging them into the AQLM-quantized weights is not possible.
+
+```py
+quantized_model = AutoModelForCausalLM.from_pretrained(
+    "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
+    torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
+)
+
+peft_config = LoraConfig(...)
+
+quantized_model = get_peft_model(quantized_model, peft_config)
+```
+
+You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning.
+
 ## Next steps

 If you're interested in learning more about quantization, the following may be helpful:
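
Since merging LoRA weights into AQLM-quantized layers is not supported, the usual workflow is to save only the adapter and re-attach it to a freshly loaded quantized base model. A minimal sketch continuing from the documentation snippet above (the output directory name is illustrative):

```py
from peft import PeftModel

# Save only the LoRA adapter weights and config; the AQLM base weights stay untouched.
quantized_model.save_pretrained("mixtral-aqlm-lora")  # hypothetical output directory

# Later: reload the prequantized base model and attach the trained adapter on top.
base_model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
    torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(base_model, "mixtral-aqlm-lora")
```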

src/peft/import_utils.py
Lines changed: 4 additions & 0 deletions

@@ -65,5 +65,9 @@ def is_torch_tpu_available(check_device=True):
     return False


+def is_aqlm_available():
+    return importlib.util.find_spec("aqlm") is not None
+
+
 def is_auto_awq_available():
     return importlib.util.find_spec("awq") is not None

src/peft/tuners/lora/aqlm.py
Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+# Copyright 2024-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import torch
+
+from peft.import_utils import is_aqlm_available
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+
+if is_aqlm_available():
+    from aqlm import QuantizedLinear
+
+
+class AqlmLoraLinear(torch.nn.Module, LoraLayer):
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        init_lora_weights: bool = True,
+        use_rslora: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora)
+
+    def forward(self, x: torch.Tensor):
+        # note: logic differs from default Linear because merging is not supported
+        result = self.base_layer(x)
+
+        if self.disable_adapters:
+            return result
+
+        for active_adapter in self.active_adapters:
+            if active_adapter not in self.lora_A.keys():
+                continue
+            lora_A = self.lora_A[active_adapter]
+            lora_B = self.lora_B[active_adapter]
+            dropout = self.lora_dropout[active_adapter]
+            scaling = self.scaling[active_adapter]
+
+            requires_conversion = not torch.is_autocast_enabled()
+            if requires_conversion:
+                expected_dtype = result.dtype
+                x = x.to(lora_A.weight.dtype)
+
+            output = lora_B(lora_A(dropout(x)))
+            if requires_conversion:
+                output = output.to(expected_dtype)
+            output = output * scaling
+            result += output
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
+
+    # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
+    # def reset_lora_parameters(self, adapter_name):
+    #     if adapter_name in self.lora_A.keys():
+    #         torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
+    #         torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
+
+
+def dispatch_aqlm(
+    target: torch.nn.Module,
+    adapter_name: str,
+    **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+    new_module = None
+
+    if isinstance(target, BaseTunerLayer):
+        target_base_layer = target.get_base_layer()
+    else:
+        target_base_layer = target
+
+    if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
+        new_module = AqlmLoraLinear(target, adapter_name, **kwargs)
+        target.qweight = target_base_layer.codes
+
+    return new_module
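
To make the dispatcher contract concrete: `dispatch_aqlm` returns `None` for any layer it does not handle, so `LoraModel` can fall through to the next dispatcher, and it only wraps the target in `AqlmLoraLinear` when the base layer is an `aqlm.QuantizedLinear`. A small sketch under that assumption (the adapter name and LoRA hyperparameters are illustrative):

```py
import torch

from peft.tuners.lora.aqlm import dispatch_aqlm

# A plain nn.Linear is not an AQLM layer, so the dispatcher declines and returns None,
# letting LoraModel fall through to the next dispatcher in the chain.
plain = torch.nn.Linear(16, 16)
assert dispatch_aqlm(plain, adapter_name="default", r=8, lora_alpha=16) is None

# With aqlm installed and an aqlm.QuantizedLinear as the base layer, the same call would
# instead return an AqlmLoraLinear wrapping the quantized layer; its constructor arguments
# are omitted here because they depend on the AQLM quantization configuration.
```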

src/peft/tuners/lora/layer.py
Lines changed: 3 additions & 0 deletions

@@ -66,6 +66,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None:
         elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
             # Megatron ColumnParallelLinear,RowParallelLinear
             in_features, out_features = base_layer.input_size, base_layer.output_size
+        elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
+            # AQLM QuantLinear
+            in_features, out_features = base_layer.in_features, base_layer.out_features
         elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
             # Awq layers
             in_features, out_features = base_layer.in_features, base_layer.out_features

src/peft/tuners/lora/model.py
Lines changed: 3 additions & 2 deletions

@@ -38,6 +38,7 @@
 )
 from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties

+from .aqlm import dispatch_aqlm
 from .awq import dispatch_awq
 from .config import LoraConfig
 from .gptq import dispatch_gptq
@@ -157,7 +158,7 @@ def _create_and_replace(
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
         }

-        quant_methods = ["gptq", "awq"]
+        quant_methods = ["gptq", "aqlm", "awq"]
         for quant_method in quant_methods:
             quantization_config = get_quantization_config(self.model, method=quant_method)
             if quantization_config is not None:
@@ -247,7 +248,7 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs):

         dispatchers.append(dispatch_bnb_4bit)

-        dispatchers.extend([dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])
+        dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default])

         new_module = None
         for dispatcher in dispatchers:
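
The change above slots `dispatch_aqlm` in ahead of the other quantization backends. The dispatchers follow a simple first-match-wins protocol; a simplified, self-contained illustration of that pattern (the `pick_module` helper below is hypothetical, not part of PEFT):

```py
from typing import Callable, List

import torch.nn as nn


def pick_module(target: nn.Module, adapter_name: str, dispatchers: List[Callable], **kwargs) -> nn.Module:
    # Each dispatcher inspects the target layer and either returns a replacement module
    # or None; the first non-None result wins. In LoraModel the list ends with
    # dispatch_default, so every supported layer type is eventually handled.
    for dispatch in dispatchers:  # e.g. [dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]
        new_module = dispatch(target, adapter_name, **kwargs)
        if new_module is not None:
            return new_module
    raise ValueError(f"No dispatcher could handle a layer of type {type(target).__name__}")
```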

src/peft/utils/other.py
Lines changed: 3 additions & 2 deletions

@@ -92,20 +92,21 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
     """
     loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
     is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+    is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
     if gradient_checkpointing_kwargs is None:
         gradient_checkpointing_kwargs = {}

     for name, param in model.named_parameters():
         # freeze base model's layers
         param.requires_grad = False

-    if not is_gptq_quantized:
+    if not is_gptq_quantized and not is_aqlm_quantized:
         # cast all non INT8 parameters to fp32
         for param in model.parameters():
             if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                 param.data = param.data.to(torch.float32)

-    if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
+    if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
         # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
         if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
             # For backward compatibility
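
In practice this means an AQLM-quantized model can go through `prepare_model_for_kbit_training` just like a GPTQ model: the base parameters are frozen and gradient checkpointing is set up, but the quantized model's half-precision parameters are not upcast to fp32. A short usage sketch mirroring the new GPU test below (the checkpoint is the one used in the tests and requires `aqlm` and `transformers>=4.38.0`):

```py
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = AutoModelForCausalLM.from_pretrained(
    "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf",
    device_map="cuda",
    torch_dtype="auto",
)
# Freezes the base weights and enables gradient checkpointing, without the fp32 upcast.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"))
```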

tests/test_gpu_examples.py
Lines changed: 88 additions & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
+import importlib
 import os
 import tempfile
 import unittest
@@ -24,6 +25,7 @@
 from accelerate.test_utils.testing import run_command
 from accelerate.utils import patch_environment
 from datasets import Audio, DatasetDict, load_dataset
+from packaging import version
 from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
@@ -53,6 +55,7 @@
 from peft.utils import SAFETENSORS_WEIGHTS_NAME

 from .testing_utils import (
+    require_aqlm,
     require_auto_awq,
     require_auto_gptq,
     require_bitsandbytes,
@@ -1383,6 +1386,91 @@ def test_model_loaded_in_float16_working(self):
         trainer.train()


+@require_torch_gpu
+@require_aqlm
+@unittest.skipUnless(
+    version.parse(importlib.metadata.version("transformers")) >= version.parse("4.38.0"),
+    "test requires `transformers>=4.38.0`",
+)
+class PeftAqlmGPUTests(unittest.TestCase):
+    r"""
+    AQLM + peft tests
+    """
+
+    def setUp(self):
+        self.causal_lm_model_id = "BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
+
+    def tearDown(self):
+        r"""
+        Efficient mechanism to free GPU memory after each test. Based on
+        https://github.com/huggingface/transformers/issues/21094
+        """
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def _check_inference_finite(self, model, batch):
+        # try inference without Trainer class
+        training = model.training
+        model.eval()
+        output = model(**batch.to(model.device))
+        assert torch.isfinite(output.logits).all()
+        model.train(training)
+
+    @pytest.mark.single_gpu_tests
+    def test_causal_lm_training_aqlm(self):
+        r"""
+        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
+        correctly.
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model = AutoModelForCausalLM.from_pretrained(
+                self.causal_lm_model_id,
+                device_map="cuda",
+                torch_dtype="auto",
+            )
+
+            model = prepare_model_for_kbit_training(model)
+            config = LoraConfig(
+                r=16,
+                lora_alpha=32,
+                target_modules=["q_proj", "v_proj"],
+                lora_dropout=0.05,
+                bias="none",
+                task_type="CAUSAL_LM",
+            )
+            model = get_peft_model(model, config)
+
+            data = load_dataset("ybelkada/english_quotes_copy")
+            data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
+
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    per_device_train_batch_size=4,
+                    gradient_accumulation_steps=4,
+                    warmup_steps=2,
+                    max_steps=3,
+                    learning_rate=2e-4,
+                    logging_steps=1,
+                    output_dir=tmp_dir,
+                    fp16=True,
+                ),
+                data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
+            )
+            model.config.use_cache = False
+            trainer.train()
+
+            model.cpu().save_pretrained(tmp_dir)
+
+            assert "adapter_config.json" in os.listdir(tmp_dir)
+            assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
+
+            # assert loss is not None
+            assert trainer.state.log_history[-1]["train_loss"] is not None
+
+
 @require_torch_gpu
 @require_auto_awq
 class PeftAwqGPUTests(unittest.TestCase):

tests/testing_utils.py
Lines changed: 8 additions & 1 deletion

@@ -18,7 +18,7 @@
 import pytest
 import torch

-from peft.import_utils import is_auto_awq_available, is_auto_gptq_available, is_optimum_available
+from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available


 def require_torch_gpu(test_case):
@@ -61,6 +61,13 @@ def require_auto_gptq(test_case):
     return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


+def require_aqlm(test_case):
+    """
+    Decorator marking a test that requires aqlm. These tests are skipped when aqlm isn't installed.
+    """
+    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
 def require_auto_awq(test_case):
     """
     Decorator marking a test that requires auto-awq. These tests are skipped when auto-awq isn't installed.
