Commit 29400c6

Initial ParetoQ commit
This project contains the training code of ParetoQ introduced in: "ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization" (https://arxiv.org/abs/2502.02631). All code is written by @liuzechun and @zxdmike and migrated from https://github.com/facebookresearch/ParetoQ. ParetoQ is the first unified framework that facilitates rigorous comparisons across 1-bit, 1.58-bit, 2-bit, 3-bit, and 4-bit quantization settings. By optimizing training schemes and refining quantization functions, ParetoQ surpasses all previous methods tailored to specific bit widths. Specifically, the 1.58-bit ParetoQ LLaMA-3 8B model reduces the performance gap to full precision by a relative 37.8% compared to the 1-bit Era’s 1.58-bit LLaMA-3 8B model, while using only 30% of the training tokens.
1 parent: 8c81863

15 files changed: +2268 −0 lines

ruff.toml (+1)

@@ -7,4 +7,5 @@ lint.ignore = ["E731"]
 # Exclude third-party modules
 exclude = [
     "third_party/*",
+    "ao/prototype/pareto_q/*",
 ]
Training launch script (+37)

@@ -0,0 +1,37 @@
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

torchrun --nnodes=1 --nproc_per_node=1 train.py \
--local_dir "/tmp/llama/" \
--input_model_filename "meta-llama/Llama-3.2-1B" \
--output_model_filename "1B-finetuned" \
--train_data_local_path "/tmp/train.jsonl" \
--do_train True \
--do_eval False \
--model_max_length 2048 \
--fp16 False \
--bf16 True \
--log_on_each_node False \
--logging_dir /tmp/output/runs/current \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--report_to "tensorboard" \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0. \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--gradient_checkpointing False \
--qat True \
--w_bits 4 \
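For context, the `--qat True` and `--w_bits 4` flags switch the trainer into quantization-aware training of the weights at the requested bit width. ParetoQ's actual quantization functions (including their learned clipping values) live in the model code added by this commit; the snippet below is only a generic sketch of per-channel symmetric fake quantization with a straight-through estimator, included to illustrate the kind of forward/backward behavior such a flag enables, not the project's exact quantizer.

```python
# Illustrative sketch only: generic per-channel symmetric fake quantization
# with a straight-through estimator (STE). Not ParetoQ's actual quantizer.
import torch


def fake_quantize_weight(w: torch.Tensor, w_bits: int) -> torch.Tensor:
    """Fake-quantize a [out_features, in_features] weight to `w_bits`."""
    if w_bits >= 16:
        return w  # treat as effectively unquantized
    if w_bits < 2:
        raise NotImplementedError("binary/ternary settings use different quantizers")
    qmax = 2 ** (w_bits - 1) - 1  # e.g. 7 for 4-bit, 1 for 2-bit
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    # STE: use the quantized weight in the forward pass, but let gradients
    # flow to the latent full-precision weight as if rounding were identity.
    return w + (w_q - w).detach()


if __name__ == "__main__":
    w = torch.randn(8, 16, requires_grad=True)
    fake_quantize_weight(w, w_bits=4).sum().backward()
    print(w.grad.shape)  # gradients reach the latent weights via the STE
```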
Evaluation launch script (+39)

@@ -0,0 +1,39 @@
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

CUDA_VISIBLE_DEVICES=0 torchrun --nnodes=1 --nproc_per_node=1 train.py \
--local_dir "/tmp/llama/" \
--input_model_filename "/tmp/llama_1B/llama_1B_bit1" \
--output_model_filename "1B-finetuned" \
--train_data_local_path "/tmp/train.jsonl" \
--eval_data_local_path "/tmp/wikitext-2/test.jsonl" \
--do_train False \
--do_eval True \
--model_max_length 2048 \
--fp16 False \
--bf16 True \
--log_on_each_node False \
--logging_dir /tmp/output/runs/current \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--report_to "tensorboard" \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0. \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--gradient_checkpointing False \
--qat True \
--w_bits 1 \
--contain_weight_clip_val True \
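The evaluation command expects a local `/tmp/wikitext-2/test.jsonl`, but this commit does not show how that file is produced. The sketch below is one hedged way to materialize it with the Hugging Face `datasets` library, assuming the eval loader in `train.py` reads one JSON object per line with a `"text"` field (the field name is an assumption).

```python
# Hedged helper: dump the wikitext-2 test split to JSON Lines.
# Assumption: the eval loader expects one {"text": ...} record per line;
# adjust the field name to whatever train.py actually reads.
import json
import os

from datasets import load_dataset

os.makedirs("/tmp/wikitext-2", exist_ok=True)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

with open("/tmp/wikitext-2/test.jsonl", "w") as f:
    for row in dataset:
        if row["text"].strip():  # drop empty lines
            f.write(json.dumps({"text": row["text"]}) + "\n")
```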

torchao/prototype/pareto_q/README.md (+78)

@@ -0,0 +1,78 @@
# ParetoQ

This repository contains the training code of ParetoQ introduced in our work: "[ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization](https://arxiv.org/abs/2502.02631)"

In this work, we present ParetoQ, the first unified framework that facilitates rigorous comparisons across 1-bit, 1.58-bit, 2-bit, 3-bit, and 4-bit quantization settings. By optimizing training schemes and refining quantization functions, ParetoQ surpasses all previous methods tailored to specific bit widths. Specifically, the 1.58-bit ParetoQ LLaMA-3 8B model reduces the performance gap to full precision by a relative 37.8% compared to the 1-bit Era’s 1.58-bit LLaMA-3 8B model, while using only 30% of the training tokens.

<div align=center>
<img width=50% src="./main_result_ternary.jpg"/>
</div>

<div align=center>
<img width=100% src="./main_result_234bit.jpg"/>
</div>

With the SoTA points obtained through ParetoQ, we are able to improve the scaling-law analysis. Figures (a) and (b) demonstrate that sub-4-bit quantization, including binary, ternary, 2-bit, and 3-bit, often outperforms 4-bit quantization. Notably, 2-bit and ternary models reside on the Pareto frontier. When considering hardware-friendliness and real-time speed, we generally recommend exploring 2-bit quantization for on-device applications.

<div align=center>
<img width=100% src="./main_result_scaling_law.jpg"/>
</div>
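As a rough illustration of why the lower bit widths matter for on-device deployment, the back-of-the-envelope sketch below (not from the paper) compares weight-storage footprints at the bit widths studied here; it ignores quantization metadata such as per-group scales, as well as activations and the KV cache.

```python
# Back-of-the-envelope weight storage for an 8B-parameter model at the
# bit widths ParetoQ compares (quantization metadata is ignored).
params = 8e9
for bits in (16, 4, 3, 2, 1.58, 1):
    gib = params * bits / 8 / 2**30
    print(f"{bits:>5} bits/weight -> ~{gib:4.1f} GiB of weights")
```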
## Citation

If you find our code useful for your research, please consider citing:

    @article{liu2025paretoq,
      title={ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization},
      author={Liu, Zechun and Zhao, Changsheng and Huang, Hanxian and Chen, Sijia and Zhang, Jing and Zhao, Jiawei and Roy, Scott and Jin, Lisa and Xiong, Yunyang and Shi, Yangyang and others},
      journal={arXiv preprint arXiv:2502.02631},
      year={2025}
    }

## Run

### 1. Requirements:
* python 3.11
* pip3 install torch
* pip install -r requirement.txt

### 2. Steps to run:
* Specify the data path and the pre-trained full-precision model path in the 1_run_train.sh file
* Run `bash 1_run_train.sh $w_bit`, e.g. `bash 1_run_train.sh 2` for 2-bit weight quantization (a sketch for sweeping several bit widths follows below).
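A hedged convenience sketch for sweeping bit widths follows; it assumes `1_run_train.sh` forwards its first argument to `--w_bits`, as the usage above implies. Note that the script committed here hardcodes `--w_bits 4`, so it would need a small edit (e.g. passing `"$1"`) for the argument to take effect.

```python
# Hypothetical sweep over several of the bit widths compared in ParetoQ.
# Assumes 1_run_train.sh forwards its first argument to --w_bits.
import subprocess

for w_bit in (1, 2, 3, 4):
    subprocess.run(["bash", "1_run_train.sh", str(w_bit)], check=True)
```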
## Comparison to SoTA Ternary LLM methods

The results reported in the paper were obtained with Meta's internal LLaMA codebase. We reproduced our experiments with the Hugging Face codebase and released the code here; the results are close to those in the paper. The task columns report zero-shot accuracy (%); Wiki is WikiText-2 perplexity (lower is better).

| Method | #Params | Arc-e | Arc-c | Boolq | Piqa | Siqa | HellaSwag | Obqa | WinoGrande | Avg. | Wiki |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| RTN | 600M | 26.2 | 24.6 | 62.2 | 49.5 | 36.3 | 26.1 | 27.1 | 48.8 | 37.6 | 6.60E+05 |
| LLM-QAT | 600M | 34.0 | 23.0 | 59.4 | 53.6 | 38.9 | 28.7 | 32.3 | 51.4 | 40.2 | 71.7 |
| 1-bit era | 700M | 49.5 | 29.0 | 59.2 | 67.5 | 43.6 | 43.2 | 38.9 | 53.5 | 48.1 | 17.3 |
| Spectra | 560M | 50.2 | 21.0 | 57.3 | 67.5 | -- | 33.8 | -- | 53.1 | -- | -- |
| **ParetoQ** | **600M** | **65.5** | **43.8** | **62.3** | **70.6** | **44.7** | **51.3** | **47.1** | **58.8** | **55.5** | **11.4** |
| RTN | 1B | 25.7 | 24.8 | 37.8 | 49.3 | 37.1 | 26.2 | 25.2 | 50.2 | 34.5 | 1.40E+05 |
| LLM-QAT | 1B | 36.0 | 26.2 | 47.7 | 55.1 | 39.7 | 31.3 | 33.5 | 49.6 | 39.9 | 56.9 |
| 1-bit era | 1.3B | 52.4 | 34.1 | 61.9 | 69.1 | 44.7 | 47.4 | 41.1 | 55.3 | 50.8 | 23.6 |
| Spectra | 1.1B | 56.3 | 24.6 | 59.1 | 69.3 | -- | 38.8 | -- | 55.5 | -- | -- |
| **ParetoQ** | **1B** | **68.5** | **47.6** | **62.8** | **72.1** | **45.3** | **57.4** | **52.9** | **61.3** | **58.5** | **10.0** |
| RTN | 3B | 26.9 | 23.6 | 62.2 | 51.3 | 37.6 | 26.4 | 27.0 | 49.3 | 38.0 | 4.40E+05 |
| LLM-QAT | 3B | 44.5 | 30.7 | 62.1 | 62.7 | 41.0 | 43.4 | 35.0 | 50.6 | 46.3 | 6.50E+02 |
| 1-bit era | 3B | 58.7 | 37.2 | 61.3 | 71.3 | 45.2 | 56.0 | 45.8 | 60.3 | 54.5 | 265.6 |
| Spectra | 3.9B | 66.0 | 31.9 | 66.5 | 74.4 | -- | 48.3 | -- | 62.1 | -- | -- |
| **ParetoQ** | **3B** | **71.5** | **48.6** | **68.2** | **75.5** | **46.4** | **67.9** | **54.3** | **63.1** | **61.9** | **9.9** |

More results for other bit widths can be found in the [paper](https://arxiv.org/abs/2502.02631).

## Acknowledgement

This code is partially based on the Hugging Face Transformers repository.

## Contact

Zechun Liu, Reality Labs, Meta Inc (zechunliu at meta dot com)

Changsheng Zhao, Reality Labs, Meta Inc (cszhao at meta dot com)

## License

ParetoQ is CC-BY-NC 4.0 licensed as of now.
(Diffs for three additional files did not load in this view.)
LLaMA model configuration (+233)

@@ -0,0 +1,233 @@
# coding=utf-8
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLaMA model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class LlamaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the LLaMA-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`LlamaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
            Llama 2 up to 4096, CodeLlama up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to the value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*):
            The attention head dimension. If None, it will default to hidden_size // num_attention_heads
        w_bits (`int`, *optional*, defaults to 32):
            Weight bit width used for ParetoQ quantization-aware training; the default of 32 corresponds to
            full-precision weights.

    ```python
    >>> from transformers import LlamaModel, LlamaConfig

    >>> # Initializing a LLaMA llama-7b style configuration
    >>> configuration = LlamaConfig()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = LlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `LlamaModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        w_bits=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
        self.w_bits = w_bits

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["LlamaConfig"]
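Relative to the upstream Hugging Face `LlamaConfig`, the functional addition here is the `w_bits` field (default 32, i.e. full-precision weights), which downstream model code can read to select a weight bit width. A minimal usage sketch follows; the import path is hypothetical and depends on where this file lands in the prototype.

```python
# Minimal sketch of the added knob. The import path is hypothetical; use
# whatever module path this configuration file gets inside the prototype.
from torchao.prototype.pareto_q.configuration_llama import LlamaConfig  # hypothetical

config = LlamaConfig(w_bits=2)   # request 2-bit weight quantization
print(config.w_bits)             # -> 2; omitting w_bits keeps the default of 32
```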
