From f2715ae1e8c37ea86a04972cbccd4418b87c8b88 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Mon, 8 Jul 2024 22:15:06 -0700
Subject: [PATCH] Refactor and clean up to keep the training simpler

---
 train.py | 38 ++++++++------------------------------
 1 file changed, 8 insertions(+), 30 deletions(-)

diff --git a/train.py b/train.py
index ce2f71f..fa6ff56 100644
--- a/train.py
+++ b/train.py
@@ -125,29 +125,6 @@ def preprocess_data(data, tokenizer, max_length, overlap):
     return sequences
 
 
-def process_dataset(dataset, tokenizer, max_length, overlap):
-    """
-    Process a dataset using the preprocess_data function.
-
-    Args:
-        dataset: The dataset to process.
-        tokenizer: Tokenizer object for encoding the data.
-        max_length (int): Maximum sequence length for each chunk.
-        overlap (int): Overlap size between consecutive chunks.
-
-    Returns:
-        list: List of preprocessed sequences from the entire dataset.
-    """
-    all_sequences = []
-    for item in dataset:
-        text = (
-            item["text"] if "text" in item else item["content"]
-        )  # Adjust based on dataset structure
-        sequences = preprocess_data(text, tokenizer, max_length, overlap)
-        all_sequences.extend(sequences)
-    return all_sequences
-
-
 def compute_perplexity(loss):
     return torch.exp(loss)
 
@@ -333,15 +310,18 @@ def main():
 
     # Load datasets mentioned in the LongRoPE paper
     pg19_dataset = load_dataset("pg19", split="train")
-    arxiv_dataset = load_dataset("arxiv_dataset", split="train")
-    github_dataset = load_dataset("github_dataset", split="train")
 
     # Set parameters for data preprocessing
     max_length = 65536
     overlap = 4096
 
     # Preprocess the data into sequences
-    sequences = preprocess_data(data, tokenizer, max_length, overlap)
+    logger.info("Preprocessing PG19 dataset...")
+    sequences = []
+    for item in pg19_dataset:
+        text = item["text"]
+        sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
+    logger.info(f"Total sequences after preprocessing: {len(sequences)}")
 
     # Create target sequences (shifted by one token)
     targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]
@@ -386,18 +366,16 @@ def main():
 
     # Check for the latest checkpoint
     latest_checkpoint = "checkpoint_latest.pt"
-
     if os.path.exists(latest_checkpoint):
         logger.info(f"Found checkpoint: {latest_checkpoint}")
         resume_from_checkpoint = latest_checkpoint
-
     else:
         logger.info("No checkpoint found, starting training from scratch")
         resume_from_checkpoint = None
 
     # Extend the context window of the model
     extended_model = model.extend_context(
-        data_path="../data/raw/enwik8.gz",
+        data=sequences,
         target_length=2048000,
         max_sequence_length=65536,
         tokenizer=tokenizer,
@@ -409,7 +387,7 @@ def main():
 
     # Recover performance on shorter contexts
     recovered_model = extended_model.recover_short_context(
-        data_path="../data/raw/enwik8.gz",
+        data=sequences,
         max_sequence_length=48192,
         tokenizer=tokenizer,
     )
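
Note (not part of the patch): preprocess_data itself is unchanged and its body does not appear in these hunks. As a reading aid, below is a minimal sketch of the overlapping-chunk preprocessing it presumably performs, inferred from its call signature and from the removed process_dataset docstring; the Hugging Face tokenizer and the function body are assumptions, illustrative only.

# Hypothetical sketch of preprocess_data (not taken from train.py).
def preprocess_data(data, tokenizer, max_length, overlap):
    """Split one document into overlapping chunks of token ids."""
    token_ids = tokenizer.encode(data)
    stride = max_length - overlap  # the patch uses max_length=65536, overlap=4096
    sequences = []
    for start in range(0, len(token_ids), stride):
        # Each window shares `overlap` tokens with the previous one.
        chunk = token_ids[start:start + max_length]
        sequences.append(chunk)
        if start + max_length >= len(token_ids):
            break  # the last window already reaches the end of the document
    return sequences

With process_dataset removed, main() runs this per-item loop over pg19_dataset directly and extends a single sequences list, which is then passed in memory (data=sequences) to extend_context and recover_short_context instead of the old enwik8 file path.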