from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer
from typing import Iterator

from preprocess import clean_comments, include, keep_only_content

from chunk import chunk
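# NOTE: `preprocess` and `chunk` are local helper modules of this repo. From their use
# below, `chunk` is assumed to return a tokenizer-style dict whose "input_ids" and
# "attention_mask" entries are tensors of shape (num_chunks, chunk_size).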

class CleanDataset(IterableDataset):
    TRAIN_SPLIT_NAME = "codeparrot/codeparrot-clean-train"
    VAL_SPLIT_NAME = "codeparrot/codeparrot-clean-valid"

    def __init__(self, train_split: bool, max_size: int = float("inf")):
        # Select the training or validation dataset
        SPLIT_NAME = CleanDataset.TRAIN_SPLIT_NAME if train_split else CleanDataset.VAL_SPLIT_NAME

        # Set max size
        self.max_size = max_size

        # Load dataset
        ds = load_dataset(SPLIT_NAME,
                          streaming=True,
                          split="train")  # Invariant for BOTH train and val sets

        # Preprocessing
        ds = ds.filter(lambda x: x["path"].endswith(".py"))               # Python only
        ds = ds.filter(lambda x: include(x["content"]))                   # DS imports only
        ds = ds.map(lambda x: {"content": clean_comments(x["content"])})  # Reformat code
        ds = ds.map(keep_only_content)                                    # Smaller samples

        # Prepare for torch DataLoader
        ds = ds.with_format("torch")

        # Enforce max size (take() expects an integer, so skip it when max_size is unbounded)
        if max_size != float("inf"):
            ds = ds.take(max_size)

        self.ds = ds

    def generate(self) -> Iterator[str]:
        i = 0  # Tracks attempt number for exception reporting

        for code_file in self.ds:
            i += 1

            # Yield when possible, skip and log when not
            try:
                yield code_file["content"]
            except StopIteration:
                break
            except Exception as e:
                print(f"[WARNING] Exception while loading sample {i}/{self.max_size}: {e}. Skipped item")
                continue

    def __iter__(self) -> Iterator[dict]:
        return self.generate()

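# ChunkedDataset tokenizes each cleaned code file and splits it into overlapping
# fixed-length chunks via the `chunk` helper, yielding ready-to-batch dicts of
# input_ids, attention_mask, and labels.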
class ChunkedDataset(CleanDataset):
    def __init__(self, train_split: bool, max_size: int, tokenizer: AutoTokenizer,
                 chunk_size: int = 256, chunk_overlap_len: int = 3, max_chunks: int = 128):

        super().__init__(train_split, max_size)

        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.overlapping_len = chunk_overlap_len
        self.max_chunks = max_chunks

    def generate(self) -> Iterator[dict]:
        count = 0

        for text in super().generate():
            # Attempt to chunk each code sample
            chunks = None
            try:
                chunks = chunk(inp=text,
                               tokenizer=self.tokenizer,
                               chunk_size=self.chunk_size,
                               overlapping_len=self.overlapping_len,
                               max_chunks=self.max_chunks)
            except Exception as e:
                print(f"[WARNING] Exception while chunking after {count}/{self.max_size} chunks: {e}. Skipped item")
                continue

            # Extract input ids and attention masks
            ids, mask = chunks["input_ids"], chunks["attention_mask"]

            # Yield each chunk, stopping if max_size is reached
            for i in range(ids.size(0)):
                # Stop yielding if max_size is reached
                if count >= self.max_size:
                    break

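                # "labels" is a copy of "input_ids": Hugging Face causal-LM models shift
                # labels internally when computing the loss, so no manual shift is needed.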
                # Yield
                yield {
                    "input_ids": ids[i],
                    "attention_mask": mask[i],
                    "labels": ids[i].clone()
                }
                count += 1

            # Stop generating new chunks if max_size is reached
            if count >= self.max_size:
                break


# SAMPLE USAGE
if __name__ == "__main__":
    try:
        tokenizer = AutoTokenizer.from_pretrained("./tokenizer_10M")
    except OSError:
        print("[WARNING] tokenizer_10M folder was not found, defaulting to GPT2")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    ds = ChunkedDataset(
        train_split=True,     # Use training split
        max_size=1_000_000,   # Provide up to 1 million samples (chunks, not files)
        tokenizer=tokenizer,  # Set tokenizer
        chunk_size=256,       # Max length of id/mask sequences is 256
        chunk_overlap_len=3,  # Chunks share 3 ids with the previous chunk
        max_chunks=128,       # Max chunks per file
    )

    # ChunkedDataset is iterable, so it can be passed directly to a DataLoader
    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset=ds,
        batch_size=16,
        # shuffle must NOT be set: DataLoader raises an error when shuffle is used with an IterableDataset
    )

    # Inspect a single batch
    for batch in loader:
        print(batch)
        break
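    # Assuming `chunk` pads every chunk to exactly chunk_size tokens (so the default
    # collate function can stack them), each batch is a dict of (batch_size, chunk_size)
    # tensors, e.g. batch["input_ids"].shape == torch.Size([16, 256]).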