Commit 07362ff

ChatGLM3-6B LoRA Fine-tuning Demo (#11450)
* ChatGLM3-6B LoRA Fine-tuning Demo
* refine
* refine
* add 2-card deepspeed
* refine format
* add mpi4py and deepspeed install
1 parent e000ac9 commit 07362ff

8 files changed: +927 -1 lines changed
@@ -0,0 +1,150 @@
# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM

This example ports the [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md).

### 1. Install

```bash
conda create -n llm python=3.11
conda activate llm
pip install "jieba>=0.42.1"
pip install "ruamel_yaml>=0.18.6"
pip install "rouge_chinese>=1.0.3"
pip install "jupyter>=1.0.0"
pip install "datasets>=2.18.0"
pip install "peft>=0.10.0"
pip install typer
pip install sentencepiece
pip install nltk
pip install "numpy<2.0.0"
pip install "deepspeed==0.13.1"
pip install "mpi4py>=3.1.5"
# the command below installs intel_extension_for_pytorch==2.1.10+xpu by default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
```
### 2. Configure OneAPI Environment Variables

```bash
source /opt/intel/oneapi/setvars.sh
```
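Before fine-tuning, you can optionally confirm that the Arc GPU is visible to PyTorch. This check is not part of the original example; it assumes the XPU wheels installed in step 1 and the oneAPI environment sourced above, where `intel_extension_for_pytorch` registers the `torch.xpu` backend.

```python
# Optional sanity check (not part of the demo scripts): verify that the Arc GPU
# is visible after installing the XPU wheels and sourcing the oneAPI environment.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the torch.xpu backend)

print(torch.xpu.is_available())      # expected: True
print(torch.xpu.device_count())      # number of visible Arc devices
print(torch.xpu.get_device_name(0))  # name of the first Arc device
```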
### 3. LoRA Fine-Tune on ChatGLM3-6B

First, download the dataset: we use `AdvertiseGen` to fine-tune ChatGLM3-6B in the following example. Get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the script below:

```bash
python process_advertise_gen_dataset.py
```

This converts `./AdvertiseGen` into `./AdvertiseGen_fix` (a sketch of the conversion is shown below). With the dataset prepared, we can now start LoRA fine-tuning on ChatGLM3-6B.
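For reference, the conversion is roughly of the following shape. This is only a hedged sketch, assuming the ChatGLM3 fine-tuning format in which each AdvertiseGen record (`content`/`summary`) becomes a two-turn conversation; `process_advertise_gen_dataset.py` in this example is the authoritative implementation.

```python
# A minimal sketch of the kind of conversion performed here (assumption:
# AdvertiseGen's "content"/"summary" pairs are rewritten into ChatGLM3's
# conversation format); see process_advertise_gen_dataset.py for the real script.
import json
from pathlib import Path

src, dst = Path("AdvertiseGen"), Path("AdvertiseGen_fix")
dst.mkdir(exist_ok=True)

for name in ("train.json", "dev.json"):
    with open(src / name, encoding="utf-8") as fin, \
         open(dst / name, "w", encoding="utf-8") as fout:
        for line in fin:
            sample = json.loads(line)
            fout.write(json.dumps({
                "conversations": [
                    {"role": "user", "content": sample["content"]},
                    {"role": "assistant", "content": sample["summary"]},
                ]
            }, ensure_ascii=False) + "\n")
```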
#### 3.1. Fine-Tune with a Single Arc Card

Start the fine-tuning by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh
```

Then, you will get output as below:
```bash
2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.47it/s]
2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format......
[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect)
trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ChatGLMForConditionalGeneration(
      (transformer): ChatGLMModel(
        (embedding): Embedding(
          (word_embeddings): Embedding(65024, 4096)
        )
        (rotary_pos_emb): RotaryEmbedding()
        (encoder): GLMTransformer(
          (layers): ModuleList(
            (0-27): 28 x GLMBlock(
              (input_layernorm): RMSNorm()
              (self_attention): SelfAttention(
                (query_key_value): LoraLowBitLinear(
                  (base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=2, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=2, out_features=4608, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                  (qa_pool): Identity()
                )
                (core_attention): CoreAttention(
                  (attention_dropout): Dropout(p=0.0, inplace=False)
                )
                (dense): BF16Linear(in_features=4096, out_features=4096, bias=False)
              )
              (post_attention_layernorm): RMSNorm()
              (mlp): MLP(
                (dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False)
                (dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False)
              )
            )
          )
          (final_layernorm): RMSNorm()
        )
        (output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False)
      )
    )
  )
)
--> Model

--> model has 0.487424M params

train_dataset: Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 114599
})
val_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
test_dataset: Dataset({
    features: ['input_ids', 'output_ids'],
    num_rows: 1070
})
--> Sanity check
'[gMASK]': 64790 -> -100
'sop': 64792 -> -100
'<|user|>': 64795 -> -100
'': 30910 -> -100
'\n': 13 -> -100
......

# Here it takes time to finish the whole fine-tuning

......

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx}
100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s]
***** Running Prediction *****
  Num examples = 1070
  Batch size = 4
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it]
```
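The `trainable params: 487,424` line is easy to verify from the printed model structure: rank-2 LoRA is applied to each layer's `query_key_value` projection (4096 in, 4608 out) across the 28 `GLMBlock` layers.

```python
# Back-of-the-envelope check of the "trainable params" line above, based on the
# printed model structure: rank-2 LoRA on query_key_value (4096 -> 4608) in 28 layers.
r, in_features, out_features, num_layers = 2, 4096, 4608, 28

per_layer = in_features * r + r * out_features    # lora_A + lora_B
trainable = per_layer * num_layers
print(trainable)                                  # 487424
print(f"{trainable / 6_244_071_424 * 100:.4f}%")  # ~0.0078%
```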
#### 3.2. Fine-Tune with 2 Arc Cards

Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by:

```bash
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh
```
@@ -0,0 +1,15 @@
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    },
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "bf16": {
    "enabled": true
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto"
}
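This DeepSpeed config enables ZeRO stage 2 with optimizer-state offload to CPU and bf16 training, and leaves micro-batch size and gradient accumulation to be resolved from the trainer (`"auto"`). As a hedged sketch (not the demo's actual training script), a JSON like this is typically attached to a Hugging Face trainer through the `deepspeed` field of the training arguments, which is what the commented-out `deepspeed: ds_zero_2.json` entry in the LoRA YAML below refers to; the values shown here are illustrative.

```python
# Hedged sketch only: how a ZeRO-2 JSON like the one above is commonly wired
# into transformers' Seq2SeqTrainingArguments (the demo's own scripts are
# authoritative; use an absolute path, per the comment in the YAML below).
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    max_steps=3000,
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    bf16=True,                    # matches "bf16": {"enabled": true}
    deepspeed="ds_zero_2.json",   # set your absolute deepspeed config path here
)
```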
@@ -0,0 +1,47 @@
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml
data_config:
  train_file: train.json
  val_file: dev.json
  test_file: dev.json
  num_proc: 16
max_input_length: 128
max_output_length: 128
training_args:
  # see `transformers.Seq2SeqTrainingArguments`
  output_dir: ./output
  max_steps: 3000
  # needed to be fit for the dataset
  learning_rate: 5e-5
  # settings for data loading
  per_device_train_batch_size: 1
  dataloader_num_workers: 16
  remove_unused_columns: false
  # settings for saving checkpoints
  save_strategy: steps
  save_steps: 500
  # settings for logging
  log_level: info
  logging_strategy: steps
  logging_steps: 10
  # settings for evaluation
  per_device_eval_batch_size: 4
  evaluation_strategy: steps
  eval_steps: 1000
  # settings for optimizer
  # adam_epsilon: 1e-6
  # uncomment the following line to detect nan or inf values
  # debug: underflow_overflow
  predict_with_generate: true
  # see `transformers.GenerationConfig`
  generation_config:
    max_new_tokens: 128
  # set your absolute deepspeed path here
  #deepspeed: ds_zero_2.json
  # set to true if train with cpu.
  use_cpu: false
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 2
  lora_alpha: 8
  lora_dropout: 0.1
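The `peft_config` section maps directly onto PEFT's `LoraConfig`; with `r: 2` and `lora_alpha: 8`, the adapter contributes only the ~487K trainable parameters reported in the single-card log above. Below is a hedged sketch of that mapping (the demo's own fine-tuning script is authoritative).

```python
# Hedged sketch: how the peft_config values above map onto the PEFT API.
# Model loading (ChatGLM3-6B through IPEX-LLM) is elided here.
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # task_type: CAUSAL_LM
    r=2,                           # r: 2
    lora_alpha=8,                  # lora_alpha: 8
    lora_dropout=0.1,              # lora_dropout: 0.1
)

# model = get_peft_model(base_model, lora_config)
# model.print_trainable_parameters()  # trainable params: 487,424 (~0.0078%)
```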
