+import math
 import os
+import signal
 import sys
 from pathlib import Path
 
+import bitsandbytes as bnb
 import fire
 import torch
 import transformers
 import yaml
 from attrdict import AttrDict
-from datasets import load_dataset, IterableDataset
+from datasets import load_dataset, IterableDataset, Dataset
 from peft import (
     LoraConfig,
     get_peft_model,
-    prepare_model_for_int8_training,
+    prepare_model_for_int8_training, get_peft_model_state_dict,
 )
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # add src to the pythonpath so we don't need to pip install this
+from transformers.trainer_pt_utils import get_parameter_names
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
 src_dir = os.path.join(project_root, 'src')
 sys.path.insert(0, src_dir)
 
-from axolotl.datasets import TokenizedPromptDataset
+from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
     LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
     if len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
         cfg.use_wandb = True
-    if len(cfg.wandb_watch) > 0:
+    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
         os.environ["WANDB_WATCH"] = cfg.wandb_watch
-    if len(cfg.wandb_log_model) > 0:
+    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
         os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
 
 
@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if tokenizer.__class__.__name__ == "LlamaTokenizer":
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
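+    # the GPT-NeoX fast tokenizer ships without a pad token, so register one explicitly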
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
     if cfg.load_in_8bit:
         model = prepare_model_for_int8_training(model)
 
@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         lora_alpha=cfg.lora_alpha,
         target_modules=cfg.lora_target_modules,
         lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
         bias="none",
         task_type="CAUSAL_LM",
     )
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     # TODO resume_from_checkpoint handling
 
     model.print_trainable_parameters()
-    return model, tokenizer
+    return model, tokenizer, lora_config
 
 
 def train(
@@ -88,7 +99,7 @@ def train(
 ):
     # load the config from the yaml file
     with open(config, 'r') as f:
-        cfg: AttrDict = AttrDict(yaml.load(f))
+        cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     for k, v in enumerate(kwargs):
@@ -107,23 +118,116 @@ def train(
     setup_wandb_env_vars(cfg)
 
     # Load the model and tokenizer
-    model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
+    model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
     datasets = []
     for d in cfg.datasets:
-        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
+        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
         if d.type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "gpteacher":
             ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
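+    # pack the tokenized prompts into fixed-length sequences, then carve out a validation split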
+    constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
+    constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+        test_size=cfg.val_set_size, shuffle=True, seed=42
+    )
+
+    print(constant_len_dataset)
+    train_dataset = constant_len_dataset["train"]
+    eval_dataset = constant_len_dataset["test"]
+
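+    # derive warmup, logging, eval, and checkpoint intervals from the total step count, each capped at a fixed maximum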
+    total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
+    warmup_steps = min(int(0.03 * total_num_steps), 100)
+    logging_steps = min(int(0.005 * total_num_steps), 10)
+    save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
+
+    training_args = transformers.TrainingArguments(
+        per_device_train_batch_size=cfg.micro_batch_size,
+        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+        warmup_steps=warmup_steps,
+        num_train_epochs=cfg.num_epochs,
+        learning_rate=cfg.learning_rate,
+        bf16=cfg.bf16,
+        tf32=cfg.tf32,
+        logging_steps=logging_steps,
+        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
+        save_strategy="steps",
+        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
+        save_steps=save_steps,
+        output_dir=cfg.output_dir,
+        save_total_limit=3,
+        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        ddp_find_unused_parameters=False if cfg.ddp else None,
+        group_by_length=cfg.group_by_length,
+        report_to="wandb" if cfg.use_wandb else None,
+        run_name=cfg.wandb_run_name if cfg.use_wandb else None,
+    )
+
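+    # apply weight decay to everything except bias and LayerNorm parameters, mirroring the Trainer's default grouping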
+    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
+            "weight_decay": 0.0,
+        },
+    ]
+
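+    # 8-bit Adam from bitsandbytes keeps optimizer state quantized to reduce GPU memory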
+    adam_bnb_optim = bnb.optim.Adam8bit(
+        optimizer_grouped_parameters,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        lr=training_args.learning_rate,
+    )
+
+    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+        adam_bnb_optim,
+        training_args.warmup_steps,
+        total_num_steps,
+    )
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+    )
+    model.config.use_cache = False
+
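+    # patch state_dict so checkpoints contain only the LoRA adapter weights rather than the full model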
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
+    ).__get__(model, type(model))
+
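+    # torch.compile is only available on PyTorch 2.x and is skipped on Windows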
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
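+    # save the adapter on Ctrl+C so an interrupted run still leaves usable weights behind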
+    signal.signal(signal.SIGINT, lambda signal, frame: (
+        model.save_pretrained(cfg.output_dir),
+        exit(0)
+    ))
+
+    # go ahead and presave the adapter config
+    lora_config.save_pretrained(cfg.output_dir)
+    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
 
+    model.save_pretrained(cfg.output_dir)
 
 if __name__ == "__main__":
     fire.Fire(train)