+import math
 import os
+import signal
 import sys
 from pathlib import Path

+import bitsandbytes as bnb
 import fire
 import torch
 import transformers
 import yaml
 from attrdict import AttrDict
-from datasets import load_dataset, IterableDataset
+from datasets import load_dataset, IterableDataset, Dataset
 from peft import (
     LoraConfig,
     get_peft_model,
-    prepare_model_for_int8_training,
+    prepare_model_for_int8_training, get_peft_model_state_dict,
 )
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer

 # add src to the pythonpath so we don't need to pip install this
+from transformers.trainer_pt_utils import get_parameter_names
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
 src_dir = os.path.join(project_root, 'src')
 sys.path.insert(0, src_dir)

-from axolotl.datasets import TokenizedPromptDataset
+from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
     LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
 from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
@@ -29,9 +35,9 @@ def setup_wandb_env_vars(cfg):
     if len(cfg.wandb_project) > 0:
         os.environ["WANDB_PROJECT"] = cfg.wandb_project
         cfg.use_wandb = True
-    if len(cfg.wandb_watch) > 0:
+    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
         os.environ["WANDB_WATCH"] = cfg.wandb_watch
-    if len(cfg.wandb_log_model) > 0:
+    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
         os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model


@@ -61,6 +67,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if tokenizer.__class__.__name__ == "LlamaTokenizer":
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

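+    # the GPT-NeoX fast tokenizer ships without a pad token, so register one and
+    # disable tokenizers parallelism to avoid the fork warning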
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
     if cfg.load_in_8bit:
         model = prepare_model_for_int8_training(model)

@@ -69,6 +79,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         lora_alpha=cfg.lora_alpha,
         target_modules=cfg.lora_target_modules,
         lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
         bias="none",
         task_type="CAUSAL_LM",
     )
@@ -79,7 +90,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     # TODO resume_from_checkpoint handling

     model.print_trainable_parameters()
-    return model, tokenizer
+    return model, tokenizer, lora_config


 def train(
@@ -88,7 +99,7 @@ def train(
 ):
     # load the config from the yaml file
     with open(config, 'r') as f:
-        cfg: AttrDict = AttrDict(yaml.load(f))
+        cfg: AttrDict = AttrDict(yaml.load(f, Loader=yaml.Loader))
     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
     # then overwrite the value
     for k, v in enumerate(kwargs):
@@ -107,23 +118,116 @@ def train(
     setup_wandb_env_vars(cfg)

     # Load the model and tokenizer
-    model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
+    model, tokenizer, lora_config = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
     datasets = []
     for d in cfg.datasets:
-        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
+        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, split=None)
         if d.type == "alpaca":
             ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "gpteacher":
             ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
         elif d.type == "sharegpt":
             ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
-            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
+            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
             datasets.append(ds_wrapper)
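+    # pack the tokenized prompt datasets into fixed-length sequences, materialize the
+    # streamed examples into an in-memory Dataset, and split off a validation set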
+    constant_len_dataset = ConstantLengthDataset(tokenizer, datasets, seq_length=cfg.sequence_len)
+    constant_len_dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
+        test_size=cfg.val_set_size, shuffle=True, seed=42
+    )
+
+    print(constant_len_dataset)
+    train_dataset = constant_len_dataset["train"]
+    eval_dataset = constant_len_dataset["test"]
+
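+    # rough step budget for the run; warmup, logging, and checkpoint intervals are
+    # scaled from it, with hard caps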
+    total_num_steps = int(math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size))
+    warmup_steps = min(int(0.03 * total_num_steps), 100)
+    logging_steps = min(int(0.005 * total_num_steps), 10)
+    save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
+
+    training_args = transformers.TrainingArguments(
+        per_device_train_batch_size=cfg.micro_batch_size,
+        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
+        warmup_steps=warmup_steps,
+        num_train_epochs=cfg.num_epochs,
+        learning_rate=cfg.learning_rate,
+        bf16=cfg.bf16,
+        tf32=cfg.tf32,
+        logging_steps=logging_steps,
+        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
+        save_strategy="steps",
+        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
+        save_steps=save_steps,
+        output_dir=cfg.output_dir,
+        save_total_limit=3,
+        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        ddp_find_unused_parameters=False if cfg.ddp else None,
+        group_by_length=cfg.group_by_length,
+        report_to="wandb" if cfg.use_wandb else None,
+        run_name=cfg.wandb_run_name if cfg.use_wandb else None,
+    )
+
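+    # mirror the Trainer's default weight-decay grouping: apply weight decay to all
+    # parameters except biases and LayerNorm weights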
+    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
+    decay_parameters = [name for name in decay_parameters if "bias" not in name]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
+            "weight_decay": training_args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
+            "weight_decay": 0.0,
+        },
+    ]
+
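+    # bitsandbytes 8-bit Adam keeps optimizer state quantized to reduce GPU memory use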
+    adam_bnb_optim = bnb.optim.Adam8bit(
+        optimizer_grouped_parameters,
+        betas=(training_args.adam_beta1, training_args.adam_beta2),
+        eps=training_args.adam_epsilon,
+        lr=training_args.learning_rate,
+    )
+
+    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
+        adam_bnb_optim,
+        training_args.warmup_steps,
+        total_num_steps,
+    )
+
+    trainer = transformers.Trainer(
+        model=model,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        args=training_args,
+        optimizers=(adam_bnb_optim, lr_scheduler),
+        data_collator=transformers.DataCollatorForSeq2Seq(
+            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+        ),
+    )
+    model.config.use_cache = False
+
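+    # patch state_dict so that checkpoints saved by the Trainer contain only the
+    # LoRA adapter weights rather than the full base model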
+    old_state_dict = model.state_dict
+    model.state_dict = (
+        lambda self, *_, **__: get_peft_model_state_dict(
+            self, old_state_dict()
+        )
+    ).__get__(model, type(model))
+
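+    # torch.compile requires PyTorch 2.x and is not supported on Windows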
+    if torch.__version__ >= "2" and sys.platform != "win32":
+        model = torch.compile(model)
+
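+    # on Ctrl+C, save the adapter weights before exiting so an interrupted run isn't lost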
+    signal.signal(signal.SIGINT, lambda signal, frame: (
+        model.save_pretrained(cfg.output_dir),
+        exit(0)
+    ))
+
+    # go ahead and presave the adapter config
+    lora_config.save_pretrained(cfg.output_dir)
+    trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)

+    model.save_pretrained(cfg.output_dir)

 if __name__ == "__main__":
     fire.Fire(train)