diff --git a/BERT_Training_Enhanced.py b/BERT_Training_Enhanced.py
new file mode 100644
index 0000000..6946af3
--- /dev/null
+++ b/BERT_Training_Enhanced.py
@@ -0,0 +1,104 @@
+import argparse
+from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.cuda.amp import GradScaler, autocast
+import torch
+
+from .model import BERT
+from .trainer import BERTTrainer
+from .dataset import BERTDataset, WordVocab
+
+# EarlyStopping is expected to come from an external helper module
+from your_module_name import EarlyStopping  # replace 'your_module_name' with the module that provides it
+
+def train():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset for training BERT")
+    parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set for evaluating the trained model")
+    parser.add_argument("-v", "--vocab_path", required=True, type=str, help="path to the saved vocabulary model")
+    parser.add_argument("-o", "--output_path", required=True, type=str, help="output path for the BERT model")
+
+    parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of the transformer model")
+    parser.add_argument("-l", "--layers", type=int, default=8, help="number of transformer layers")
+    parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
+    parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence length")
+
+    parser.add_argument("-b", "--batch_size", type=int, default=64, help="batch size")
+    parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
+    parser.add_argument("-w", "--num_workers", type=int, default=5, help="number of dataloader workers")
+
+    parser.add_argument("--with_cuda", type=bool, default=True, help="train with CUDA: true or false")
+    parser.add_argument("--log_freq", type=int, default=10, help="print loss every n iterations")
+    parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in the corpus")
+    parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device IDs")
+    parser.add_argument("--on_memory", type=bool, default=True, help="load the full dataset into memory: true or false")
+
+    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of Adam")
+    parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight decay for Adam")
+    parser.add_argument("--adam_beta1", type=float, default=0.9, help="Adam's first beta value")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="Adam's second beta value")
+
+    # New features
+    parser.add_argument("--dynamic_lr", type=bool, default=True, help="use dynamic learning rate adjustment")
+    parser.add_argument("--early_stopping", type=bool, default=True, help="enable early stopping")
+    parser.add_argument("--patience", type=int, default=3, help="patience for early stopping and LR scheduling")
+    parser.add_argument("--mixed_precision", type=bool, default=True, help="use mixed precision training")
+    parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="number of gradient accumulation steps")
+    parser.add_argument("--data_augmentation", type=bool, default=False, help="apply data augmentation techniques")
+
+    args = parser.parse_args()
+
+    print("Loading Vocab", args.vocab_path)
+    vocab = WordVocab.load_vocab(args.vocab_path)
+    print("Vocab Size: ", len(vocab))
+
+    print("Loading Train Dataset", args.train_dataset)
+    train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
+                                corpus_lines=args.corpus_lines, on_memory=args.on_memory,
+                                data_augmentation=args.data_augmentation)
+
+    print("Loading Test Dataset", args.test_dataset)
+    test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
+        if args.test_dataset is not None else None
+
+    print("Creating Dataloader")
+    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
+    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
+        if test_dataset is not None else None
+
+    print("Building BERT model")
+    bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads)
+
+    print("Creating BERT Trainer")
+    trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
+                          lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
+                          with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
+                          mixed_precision=args.mixed_precision, grad_accumulation_steps=args.grad_accumulation_steps)
+
+    # Dynamic learning rate adjustment: halve the LR when the test loss plateaus
+    if args.dynamic_lr:
+        scheduler = ReduceLROnPlateau(trainer.optimizer, mode='min', factor=0.5, patience=args.patience, verbose=True)
+
+    # Early stopping on the test loss
+    early_stopping = None
+    if args.early_stopping:
+        early_stopping = EarlyStopping(patience=args.patience, verbose=True)
+
+    print("Training Start")
+    for epoch in range(args.epochs):
+        trainer.train(epoch)
+
+        if test_data_loader is not None:
+            test_loss = trainer.test(epoch)
+
+            if args.dynamic_lr:
+                scheduler.step(test_loss)
+
+            if early_stopping is not None:
+                early_stopping(test_loss, trainer.model)
+                if early_stopping.early_stop:
+                    print("Early stopping")
+                    break
+
+        trainer.save(epoch, args.output_path)
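
Note on the EarlyStopping import: the patch pulls EarlyStopping from a placeholder module ('your_module_name'), which is left as-is above. For reference, here is a minimal sketch of a compatible helper, matching only the usage in the patch (constructed with patience/verbose, called with (test_loss, model), and exposing an early_stop flag); the delta threshold and checkpoint path are assumptions, not part of the patch.

import torch

class EarlyStopping:
    """Stop training once the monitored loss has not improved for `patience` evaluations."""

    def __init__(self, patience=3, verbose=False, delta=0.0, path="checkpoint.pt"):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta            # minimum decrease that counts as an improvement
        self.path = path              # where the best weights are checkpointed
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: remember it, reset the counter, checkpoint the model.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
            if self.verbose:
                print(f"Loss improved to {val_loss:.6f}; checkpoint saved to {self.path}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} evaluations")
            if self.counter >= self.patience:
                self.early_stop = True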
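
A caveat on the boolean options: argparse's type=bool converts any non-empty string to True, so --with_cuda False (or --mixed_precision false) still evaluates to True. One way to keep the existing flag names while parsing true/false strings correctly is a small converter like the one below; the helper name str2bool is an assumption, not part of the patch.

import argparse

def str2bool(value):
    """Parse the usual true/false spellings for argparse boolean options."""
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {value!r}")

# Usage: swap type=bool for type=str2bool on the boolean options, e.g.
# parser.add_argument("--with_cuda", type=str2bool, default=True,
#                     help="train with CUDA: true or false")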
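
The mixed_precision and grad_accumulation_steps arguments are forwarded to BERTTrainer, whose implementation is outside this patch (the GradScaler/autocast imports in the script itself are otherwise unused). A rough sketch of how a trainer's inner loop could combine torch.cuda.amp with gradient accumulation is shown below; it assumes a generic (inputs, labels) batch, and none of these names come from BERTTrainer's actual API.

from torch.cuda.amp import GradScaler, autocast

def train_iteration(model, criterion, optimizer, data_loader, device,
                    mixed_precision=True, grad_accumulation_steps=1):
    scaler = GradScaler(enabled=mixed_precision)
    optimizer.zero_grad()
    for step, (inputs, labels) in enumerate(data_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        # Run the forward pass in reduced precision when enabled.
        with autocast(enabled=mixed_precision):
            loss = criterion(model(inputs), labels)
            loss = loss / grad_accumulation_steps      # average over accumulated micro-batches
        # Scale the loss so small fp16 gradients do not underflow during backward.
        scaler.scale(loss).backward()
        if (step + 1) % grad_accumulation_steps == 0:
            scaler.step(optimizer)     # unscales the gradients, then calls optimizer.step()
            scaler.update()
            optimizer.zero_grad()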