
Commit b2629b6

Awq 4 bit quantization support (#2508)

* add awq linear from AutoAWQ and/or llm-awq
* add generic converter for llama-like models from HF, with or without awq quantization

1 parent 3d4c8de commit b2629b6

6 files changed  +582 −16 lines changed

README.md  +16 −1
@@ -63,7 +63,7 @@ Install `OpenNMT-py` from `pip`:
 pip install OpenNMT-py
 ```
 
-or from the sources:
+or from the source:
 ```bash
 git clone https://github.com/OpenNMT/OpenNMT-py.git
 cd OpenNMT-py
@@ -107,6 +107,21 @@ When using `max_relative_positions > 0` or Alibi `max_relative_positions=-2` Ope
 
 flash attention and `F.scaled_dot_product_attention` are a bit faster and saves some GPU memory.
 
+
+AWQ:
+
+If you want to run inference or quantize an AWQ model you will need llm-awq and/or AutoAWQ.
+
+For [llm-awq](https://github.com/mit-han-lab/llm-awq):
+    git clone https://github.com/mit-han-lab/llm-awq
+    cd llm-awq
+    pip install -e .
+    cd ..
+
+For [AutoAWQ](https://github.com/casper-hansen/AutoAWQ):
+    pip install autoawq
+
+
 ## Documentation & FAQs
 
 [Full HTML Documentation](https://opennmt.net/OpenNMT-py/quickstart.html)
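
Note: the README leaves the choice of AWQ backend to the user. As a quick sanity check, the snippet below (a sketch, not part of this commit) probes which backend is importable, using the same import paths that `onmt/modules/awq_linear.py` relies on further down; llm-awq and AutoAWQ both install under the `awq` package name, so usually only one is present.

```python
# Sketch (not part of this commit): probe which AWQ backend is importable,
# using the same import paths as onmt/modules/awq_linear.py below.
def available_awq_backends():
    backends = []
    try:
        # llm-awq kernel wrapper
        from awq.quantize.qmodule import WQLinear  # noqa: F401

        backends.append("llm_awq")
    except ImportError:
        pass
    try:
        # AutoAWQ GEMM/GEMV kernel wrappers
        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV  # noqa: F401

        backends.extend(["aawq_gemm", "aawq_gemv"])
    except ImportError:
        pass
    return backends


if __name__ == "__main__":
    print("usable -quant_type values:", available_awq_backends() or "none")
```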

onmt/decoders/transformer.py  +13 −12
@@ -899,15 +899,16 @@ def forward(self, tgt, enc_out=None, step=None, **kwargs):
 
     def _init_cache(self, tgt=None):
         for layer in self.transformer_layers:
-            if isinstance(layer.self_attn, AverageAttention):
-                raise NotImplementedError
-            else:
-                layer.self_attn.layer_cache = (
-                    True,
-                    {
-                        "keys": torch.tensor([], device=tgt.device),
-                        "values": torch.tensor([], device=tgt.device),
-                    },
-                )
-            if hasattr(layer.self_attn, "rope"):
-                layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
+            if hasattr(layer, "self_attn"):
+                if isinstance(layer.self_attn, AverageAttention):
+                    raise NotImplementedError
+                else:
+                    layer.self_attn.layer_cache = (
+                        True,
+                        {
+                            "keys": torch.tensor([], device=tgt.device),
+                            "values": torch.tensor([], device=tgt.device),
+                        },
+                    )
+                if hasattr(layer.self_attn, "rope"):
+                    layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
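
Note: the only behavioural change here is the new `hasattr(layer, "self_attn")` guard, so decoder layers that do not expose a self-attention module are simply skipped during cache initialisation. A toy illustration of the same guard pattern (the layer classes are made up, not OpenNMT-py's):

```python
# Toy illustration of the guard added above; WithAttn/WithoutAttn are
# hypothetical stand-ins, not OpenNMT-py classes.
import torch
import torch.nn as nn


class WithAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(8, 2)


class WithoutAttn(nn.Module):
    pass


layers = nn.ModuleList([WithAttn(), WithoutAttn()])
device = torch.device("cpu")

for layer in layers:
    if hasattr(layer, "self_attn"):  # skip layers that carry no self-attention
        layer.self_attn.layer_cache = (
            True,
            {
                "keys": torch.tensor([], device=device),
                "values": torch.tensor([], device=device),
            },
        )

print([hasattr(layer, "self_attn") for layer in layers])  # [True, False]
```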

onmt/model_builder.py  +22 −2
@@ -95,8 +95,13 @@ def load_test_model(opt, device_id=0, model_path=None):
 
     model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
 
-    model_opt.quant_layers = opt.quant_layers
-    model_opt.quant_type = opt.quant_type
+    if hasattr(model_opt, "quant_type") and model_opt.quant_type not in [
+        "llm_awq",
+        "aawq_gemm",
+        "aawq_gemv",
+    ]:
+        model_opt.quant_layers = opt.quant_layers
+        model_opt.quant_type = opt.quant_type
 
     if opt.world_size > 1 and opt.parallel_mode == "tensor_parallel":
         model_opt.world_size = opt.world_size
@@ -304,6 +309,21 @@ def build_base_model(model_opt, vocabs):
             model = replace_bnb_linear(
                 model, module_to_convert=nonlora_to_quant, q_type=model_opt.quant_type
             )
+        elif model_opt.quant_type in ["llm_awq", "aawq_gemm", "aawq_gemv"]:
+            logger.info(
+                "%s compression of layer %s" % (model_opt.quant_type, nonlora_to_quant)
+            )
+            try:
+                from onmt.modules.awq_linear import replace_awq_linear
+            except ImportError:
+                raise ImportError("Install llm-awq/AutoAWQ to use awq quantized model")
+            model = replace_awq_linear(
+                model,
+                module_to_convert=nonlora_to_quant,
+                w_bit=model_opt.w_bit,
+                group_size=model_opt.group_size,
+                q_type=model_opt.quant_type,
+            )
         else:
             logger.info("compression type %s not supported." % model_opt.quant_type)
 
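
Note: in plain terms, the `load_test_model` change means that for checkpoints already quantized with AWQ the settings stored in the checkpoint win, while for everything else the inference-time `-quant_layers`/`-quant_type` options still override them. A minimal restatement of that rule as a pure function (names and example layer lists are illustrative, not from the codebase):

```python
# Sketch of the precedence rule encoded in load_test_model above;
# function and variable names are illustrative, not from OpenNMT-py.
AWQ_TYPES = {"llm_awq", "aawq_gemm", "aawq_gemv"}


def effective_quant_settings(ckpt_type, ckpt_layers, cli_type, cli_layers):
    """Checkpoint settings win for AWQ models; CLI options win otherwise."""
    if ckpt_type in AWQ_TYPES:
        return ckpt_type, ckpt_layers
    return cli_type, cli_layers


# An AWQ checkpoint keeps its own 4-bit settings:
print(effective_quant_settings("aawq_gemm", ["linear_keys"], "bnb_8bit", []))
# A non-AWQ checkpoint can still be re-quantized at load time:
print(effective_quant_settings("", [], "bnb_NF4", ["w_1", "w_2"]))
```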

onmt/modules/awq_linear.py  +38
@@ -0,0 +1,38 @@
+import torch.nn as nn
+
+
+def replace_awq_linear(
+    model, module_to_convert=[], w_bit=4, group_size=128, q_type="llm_awq"
+):
+    if q_type == "llm_awq":
+        try:
+            from awq.quantize.qmodule import WQLinear
+        except ImportError:
+            raise ImportError("Install llm-awq to use awq")
+        AWQLin = WQLinear
+    elif q_type in ["aawq_gemm", "aawq_gemv"]:
+        try:
+            from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
+        except ImportError:
+            raise ImportError("Install AutoAWQ to use awq")
+        if q_type == "aawq_gemm":
+            AWQLin = WQLinear_GEMM
+        else:
+            AWQLin = WQLinear_GEMV
+    else:
+        raise ValueError("No Awq framework for this value")
+
+    for name, module in model.named_children():
+        if len(list(module.children())) > 0:
+            replace_awq_linear(module, module_to_convert, w_bit, group_size, q_type)
+
+        if isinstance(module, nn.Linear) and name in module_to_convert:
+            model._modules[name] = AWQLin(
+                w_bit=w_bit,
+                group_size=group_size,
+                in_features=module.in_features,
+                out_features=module.out_features,
+                bias=module.bias is not None,
+                dev=module.weight.device,
+            )
+    return model
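
Note: for context, this is how a caller outside `build_base_model` might exercise the new helper. The sketch below is hypothetical: it requires AutoAWQ and a CUDA device, the module names are invented, and the swapped-in AWQ layers are allocated empty, exactly as in `build_base_model`, with the quantized weights expected to come from a checkpoint afterwards.

```python
# Hypothetical usage sketch (not part of this commit). Requires AutoAWQ and a
# CUDA device; the swapped-in WQLinear_GEMM modules are created empty and would
# normally be filled from a quantized checkpoint afterwards.
import torch.nn as nn

from onmt.modules.awq_linear import replace_awq_linear


class TinyBlock(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        self.linear_keys = nn.Linear(d_model, d_model)
        self.linear_values = nn.Linear(d_model, d_model)
        self.final_linear = nn.Linear(d_model, d_model)


model = TinyBlock().to("cuda")
model = replace_awq_linear(
    model,
    module_to_convert=["linear_keys", "linear_values"],  # names to swap (illustrative)
    w_bit=4,
    group_size=128,
    q_type="aawq_gemm",
)
print(model)  # linear_keys / linear_values are now AutoAWQ WQLinear_GEMM modules
```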

onmt/opts.py  +17 −1
@@ -1565,10 +1565,26 @@ def _add_quant_opts(parser):
         "--quant_type",
         "-quant_type",
         default="bnb_8bit",
-        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4"],
+        choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
         type=str,
         help="Type of compression.",
     )
+    group.add(
+        "--w_bit",
+        "-w_bit",
+        type=int,
+        default=4,
+        choices=[4],
+        help="W_bit quantization.",
+    )
+    group.add(
+        "--group_size",
+        "-group_size",
+        default=128,
+        choices=[128],
+        type=int,
+        help="group size quantization.",
+    )
 
 
 def train_opts(parser):
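
Note: for reference, the new options boil down to the following surface, shown here on a plain `argparse` parser as a standalone sketch (OpenNMT-py registers them through its own `group.add` wrapper; only the accepted values are demonstrated):

```python
# Standalone sketch of the option surface added above, using plain argparse.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quant_type", "-quant_type", type=str, default="bnb_8bit",
    choices=["bnb_8bit", "bnb_FP4", "bnb_NF4", "llm_awq", "aawq_gemm", "aawq_gemv"],
    help="Type of compression.",
)
parser.add_argument("--w_bit", "-w_bit", type=int, default=4, choices=[4],
                    help="W_bit quantization.")
parser.add_argument("--group_size", "-group_size", type=int, default=128,
                    choices=[128], help="group size quantization.")

# Only 4-bit weights with group size 128 are accepted for the AWQ types.
print(parser.parse_args(["-quant_type", "aawq_gemm", "-w_bit", "4", "-group_size", "128"]))
```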
