Commit

Update notebook training code
jshuadvd committed Jul 12, 2024
1 parent 3c67e6f commit 013cf7a
Showing 1 changed file with 142 additions and 48 deletions: notebooks/01_LongRoPE_training.ipynb
@@ -375,64 +375,158 @@
"outputs": [],
"source": [
"def main():\n",
" \"\"\"Main function to setup and run training.\"\"\"\n",
" \"\"\"\n",
" Main function to set up and run the LongRoPE model training process.\n",
" \"\"\"\n",
"\n",
" # Initialize Weights & Biases for experiment tracking\n",
" wandb.init(project=\"longrope\", entity=\"your-entity-name\")\n",
"\n",
" # Load and configure the tokenizer\n",
" tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
" tokenizer.model_max_length = 2048000 # Set the maximum sequence length for the tokenizer\n",
" data = load_data(\"../data/raw/enwik8.gz\")\n",
" tokenizer.model_max_length = 2048000 # Set maximum sequence length to 2048k tokens\n",
"\n",
" max_length = 65536\n",
" overlap = 4096\n",
" sequences = preprocess_data(data, tokenizer, max_length, overlap)\n",
" # Load the PG19 dataset\n",
" pg19_dataset = load_dataset(\"pg19\", split=\"train\")\n",
"\n",
" targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]\n",
" # Define sequence lengths for progressive training\n",
" sequence_lengths = [2048, 128000, 256000, 2048000]\n",
"\n",
" validate_targets(targets, tokenizer.vocab_size)\n",
" for length in sequence_lengths:\n",
" logger.info(f\"Training on sequence length: {length}\")\n",
"\n",
" print(f\"Validated: {validate_targets(targets, tokenizer.vocab_size)}\")\n",
" # Set parameters for data preprocessing\n",
" max_length = min(length, 65536)\n",
" overlap = 4096\n",
"\n",
" dataset = CustomDataset(sequences, targets)\n",
" train_size = int(0.8 * len(dataset))\n",
" val_size = len(dataset) - train_size\n",
" train_dataset, val_dataset = torch.utils.data.random_split(\n",
" dataset, [train_size, val_size]\n",
" )\n",
" # Preprocess the data into sequences\n",
" logger.info(f\"Preprocessing PG19 dataset for length {length}...\")\n",
" sequences = []\n",
" for item in pg19_dataset:\n",
" text = item[\"text\"]\n",
" sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))\n",
" logger.info(f\"Total sequences after preprocessing: {len(sequences)}\")\n",
"\n",
" train_loader = DataLoader(\n",
" train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn\n",
" )\n",
" val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)\n",
"\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" \n",
" model = LongRoPEModel(\n",
" d_model=4096,\n",
" n_heads=32,\n",
" num_layers=6,\n",
" vocab_size=tokenizer.vocab_size,\n",
" max_len=2048000, # Set max_len to 2048k tokens\n",
" ).to(device)\n",
"\n",
" extended_model = model.extend_context(\n",
" data_path=\"../data/raw/enwik8.gz\",\n",
" target_length=2048000, # Set target_length to 2048k tokens\n",
" max_sequence_length=65536,\n",
" tokenizer=tokenizer,\n",
" population_size=64,\n",
" num_mutations=16,\n",
" num_crossovers=16,\n",
" max_iterations=10,\n",
" )\n",
" # Create target sequences (shifted by one token)\n",
" targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]\n",
"\n",
" recovered_model = model.recover_short_context(\n",
" data_path=\"../data/raw/enwik8.gz\",\n",
" max_sequence_length=48192,\n",
" tokenizer=tokenizer,\n",
" )\n",
" # Validate that all target indices are within the vocabulary size\n",
" validate_targets(targets, tokenizer.vocab_size)\n",
"\n",
" optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)\n",
" criterion = nn.CrossEntropyLoss()\n",
" # Create a custom dataset from sequences and targets\n",
" dataset = CustomDataset(sequences, targets)\n",
"\n",
" # Split the dataset into training and validation sets\n",
" train_size = int(0.8 * len(dataset))\n",
" val_size = len(dataset) - train_size\n",
" train_dataset, val_dataset = torch.utils.data.random_split(\n",
" dataset, [train_size, val_size]\n",
" )\n",
"\n",
" # Create data loaders for training and validation\n",
" train_loader = DataLoader(\n",
" train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn\n",
" )\n",
" val_loader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)\n",
"\n",
" # Initialize or extend the LongRoPE model based on the current sequence length\n",
" if length == 2048:\n",
" # Initialize the base LongRoPE model\n",
" model = LongRoPEModel(\n",
" d_model=4096,\n",
" n_heads=32,\n",
" num_layers=6,\n",
" vocab_size=tokenizer.vocab_size,\n",
" max_len=length,\n",
" )\n",
" else:\n",
" # Extend the context window of the model\n",
" model = model.extend_context(\n",
" data=sequences,\n",
" target_length=length,\n",
" max_sequence_length=max_length,\n",
" tokenizer=tokenizer,\n",
" population_size=64,\n",
" num_mutations=16,\n",
" num_crossovers=16,\n",
" max_iterations=10,\n",
" )\n",
"\n",
" train(recovered_model, train_loader, val_loader, optimizer, criterion, device)\n",
" # Set up optimizer, loss function, and learning rate scheduler\n",
" optimizer = optim.AdamW(model.parameters(), lr=1e-4)\n",
" criterion = nn.CrossEntropyLoss()\n",
" scheduler = CosineAnnealingLR(optimizer, T_max=10)\n",
"\n",
" # Prepare model, optimizer, data loaders, and scheduler for distributed training\n",
" model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(\n",
" model, optimizer, train_loader, val_loader, scheduler\n",
" )\n",
"\n",
" # Check for the latest checkpoint specific to this sequence length\n",
" latest_checkpoint = f\"checkpoint_latest_{length}.pt\"\n",
" if os.path.exists(latest_checkpoint):\n",
" logger.info(f\"Found checkpoint for length {length}: {latest_checkpoint}\")\n",
" resume_from_checkpoint = latest_checkpoint\n",
" else:\n",
" logger.info(\n",
" f\"No checkpoint found for length {length}, starting training from scratch\"\n",
" )\n",
" resume_from_checkpoint = None\n",
"\n",
" # Perform training or fine-tuning based on the current sequence length\n",
" if length in [128000, 256000]:\n",
" # Fine-tuning for specific steps as mentioned in the LongRoPE paper\n",
" fine_tune_steps = 400 if length == 128000 else 600\n",
" train(\n",
" model,\n",
" train_loader,\n",
" val_loader,\n",
" optimizer,\n",
" criterion,\n",
" scheduler,\n",
" tokenizer,\n",
" epochs=1,\n",
" gradient_accumulation_steps=fine_tune_steps // len(train_loader),\n",
" resume_from_checkpoint=resume_from_checkpoint,\n",
" max_steps=fine_tune_steps,\n",
" )\n",
" else:\n",
" # Regular training for other sequence lengths\n",
" train(\n",
" model,\n",
" train_loader,\n",
" val_loader,\n",
" optimizer,\n",
" criterion,\n",
" scheduler,\n",
" tokenizer,\n",
" resume_from_checkpoint=resume_from_checkpoint,\n",
" )\n",
"\n",
" # Recover performance on shorter contexts after 256k extension\n",
" if length == 256000:\n",
" model = model.recover_short_context(\n",
" data=sequences,\n",
" max_sequence_length=48192,\n",
" tokenizer=tokenizer,\n",
" )\n",
"\n",
" # Add a simple validation step after short context recovery\n",
" model.eval()\n",
" with torch.no_grad():\n",
" val_loss = sum(\n",
" criterion(model(inputs), targets).item()\n",
" for inputs, targets in val_loader\n",
" ) / len(val_loader)\n",
" logger.info(f\"Validation loss after short context recovery: {val_loss:.4f}\")\n",
" wandb.log({\"short_context_val_loss\": val_loss})\n",
"\n",
" # Save the final model\n",
" accelerator.save_state(\"final_model.pt\")\n",
" wandb.save(\"final_model.pt\")\n",
"\n",
" # Finish logging and close the Weights & Biases run\n",
" wandb.finish()\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
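Note: the cell above relies on helpers defined elsewhere in the repository (preprocess_data, validate_targets, collate_fn). Below is a minimal sketch of the sliding-window chunking implied by the call preprocess_data(text, tokenizer, max_length, overlap); the function name and signature come from the notebook, but the body is an assumption, not the repository's implementation.

def preprocess_data(text, tokenizer, max_length, overlap):
    """Hypothetical sketch: split one document into overlapping token-id chunks."""
    token_ids = tokenizer.encode(text)
    sequences = []
    stride = max_length - overlap  # advance the window while keeping `overlap` tokens of shared context
    for start in range(0, len(token_ids), stride):
        chunk = token_ids[start : start + max_length]
        if chunk:
            sequences.append(chunk)
    return sequences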

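Likewise, validate_targets and collate_fn are only sketched here from their call sites; these are assumed interfaces, not the repository's code. The padding_value of -100 for targets is an extra assumption chosen to match the default ignore_index of nn.CrossEntropyLoss.

import torch
from torch.nn.utils.rnn import pad_sequence

def validate_targets(targets, vocab_size):
    """Hypothetical sketch: ensure every target index is a valid vocabulary id."""
    for seq in targets:
        for token_id in seq:
            if not 0 <= token_id < vocab_size:
                raise ValueError(f"Target id {token_id} is outside vocabulary of size {vocab_size}")
    return True

def collate_fn(batch):
    """Hypothetical sketch: pad a batch of (input, target) pairs to a common length."""
    inputs, targets = zip(*batch)
    inputs = pad_sequence([torch.tensor(x) for x in inputs], batch_first=True, padding_value=0)
    targets = pad_sequence([torch.tensor(y) for y in targets], batch_first=True, padding_value=-100)
    return inputs, targets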
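Finally, the search parameters passed to extend_context (population_size, num_mutations, num_crossovers, max_iterations) suggest the evolutionary search over RoPE rescale factors described in the LongRoPE paper. The sketch below only illustrates how such parameters typically drive a search loop; it is not the repository's extend_context, and evaluate_perplexity, mutate, and crossover are hypothetical callables supplied by the caller.

import random

def evolutionary_search(initial_factors, evaluate_perplexity, mutate, crossover,
                        population_size=64, num_mutations=16, num_crossovers=16,
                        max_iterations=10):
    """Illustrative sketch: evolve rescale-factor candidates to minimize perplexity."""
    # Start from perturbed copies of the initial rescale factors
    population = [mutate(initial_factors) for _ in range(population_size)]
    for _ in range(max_iterations):
        # Keep the best-scoring half of the population (lower perplexity is better)
        population.sort(key=evaluate_perplexity)
        parents = population[: population_size // 2]
        # Generate new candidates by mutation and crossover of surviving parents
        children = [mutate(random.choice(parents)) for _ in range(num_mutations)]
        children += [crossover(*random.sample(parents, 2)) for _ in range(num_crossovers)]
        population = parents + children
    return min(population, key=evaluate_perplexity)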