Refactor and clean up to keep the training simpler
jshuadvd committed Jul 9, 2024
1 parent 4b03d93 commit f2715ae
Showing 1 changed file with 8 additions and 30 deletions.
38 changes: 8 additions & 30 deletions train.py
@@ -125,29 +125,6 @@ def preprocess_data(data, tokenizer, max_length, overlap):
    return sequences


def process_dataset(dataset, tokenizer, max_length, overlap):
    """
    Process a dataset using the preprocess_data function.

    Args:
        dataset: The dataset to process.
        tokenizer: Tokenizer object for encoding the data.
        max_length (int): Maximum sequence length for each chunk.
        overlap (int): Overlap size between consecutive chunks.

    Returns:
        list: List of preprocessed sequences from the entire dataset.
    """
    all_sequences = []
    for item in dataset:
        text = (
            item["text"] if "text" in item else item["content"]
        )  # Adjust based on dataset structure
        sequences = preprocess_data(text, tokenizer, max_length, overlap)
        all_sequences.extend(sequences)
    return all_sequences


def compute_perplexity(loss):
    return torch.exp(loss)
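
As an aside (not part of this commit), a minimal sketch of how a helper like compute_perplexity above is typically applied to a mean token-level cross-entropy loss in PyTorch; the tensor shapes and names here are illustrative only:

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, seq_len, vocab_size) logits and integer targets.
logits = torch.randn(2, 8, 100)
targets = torch.randint(0, 100, (2, 8))

# Mean cross-entropy over all tokens; the exponential of this mean loss is the
# perplexity, which is exactly what compute_perplexity(loss) returns.
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
print(torch.exp(loss).item())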

@@ -333,15 +310,18 @@ def main():

    # Load datasets mentioned in the LongRoPE paper
    pg19_dataset = load_dataset("pg19", split="train")
    arxiv_dataset = load_dataset("arxiv_dataset", split="train")
    github_dataset = load_dataset("github_dataset", split="train")

    # Set parameters for data preprocessing
    max_length = 65536
    overlap = 4096

    # Preprocess the data into sequences
    sequences = preprocess_data(data, tokenizer, max_length, overlap)
    logger.info("Preprocessing PG19 dataset...")
    sequences = []
    for item in pg19_dataset:
        text = item["text"]
        sequences.extend(preprocess_data(text, tokenizer, max_length, overlap))
    logger.info(f"Total sequences after preprocessing: {len(sequences)}")

    # Create target sequences (shifted by one token)
    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]
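
The body of preprocess_data is not shown in this diff; as an assumption, an overlap-based chunker consistent with the max_length and overlap parameters above might look like the sketch below, together with the one-token target shift used in main(). The function names and exact windowing are hypothetical, not the repository's implementation.

# Hypothetical sketch: sliding-window chunking with stride (max_length - overlap).
def chunk_token_ids(token_ids, max_length, overlap):
    stride = max_length - overlap
    chunks = []
    for start in range(0, len(token_ids), stride):
        chunk = token_ids[start : start + max_length]
        if len(chunk) < 2:  # too short to form an input/target pair
            break
        chunks.append(chunk)
    return chunks

# Targets are the inputs shifted left by one token and padded with EOS,
# mirroring the list comprehension in main() above.
def shift_targets(sequences, eos_token_id):
    return [seq[1:] + [eos_token_id] for seq in sequences]

With max_length=65536 and overlap=4096 as set above, consecutive chunks would share 4,096 tokens, giving a stride of 61,440 tokens per window.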
@@ -386,18 +366,16 @@ def main():

    # Check for the latest checkpoint
    latest_checkpoint = "checkpoint_latest.pt"

    if os.path.exists(latest_checkpoint):
        logger.info(f"Found checkpoint: {latest_checkpoint}")
        resume_from_checkpoint = latest_checkpoint

    else:
        logger.info("No checkpoint found, starting training from scratch")
        resume_from_checkpoint = None

    # Extend the context window of the model
    extended_model = model.extend_context(
        data_path="../data/raw/enwik8.gz",
        data=sequences,
        target_length=2048000,
        max_sequence_length=65536,
        tokenizer=tokenizer,
@@ -409,7 +387,7 @@

    # Recover performance on shorter contexts
    recovered_model = extended_model.recover_short_context(
        data_path="../data/raw/enwik8.gz",
        data=sequences,
        max_sequence_length=48192,
        tokenizer=tokenizer,
    )
