Commit a87f13d

danaiamirali, Nishant2022, aarushis18, dennisfarmer, and khhan authored

finish model training code (#41)

Co-authored-by: Amirali Danai <[email protected]>
Co-authored-by: Nishant Dash <[email protected]>
Co-authored-by: Aarushi Shah <[email protected]>
Co-authored-by: Dennis <[email protected]>
Co-authored-by: khhan <[email protected]>
Co-authored-by: jechingliao43 <[email protected]>
Co-authored-by: Jeffrey Lu <[email protected]>
Co-authored-by: maxim12313 <[email protected]>
Co-authored-by: Selina <[email protected]>
1 parent f5869d5 commit a87f13d

File tree

13 files changed: +631 -0 lines changed


.gitignore

Lines changed: 5 additions & 0 deletions
@@ -1,7 +1,12 @@
+# Model files
+tokenizer_10M
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
 *$py.class
+checkpoints/
+checkpoints/*
 
 # C extensions
 *.so

education/docker/Dockerfile

Lines changed: 14 additions & 0 deletions
FROM python:3.10.15-slim-bullseye

WORKDIR /app

COPY requirements.txt requirements.txt

RUN pip install --compile --no-cache-dir -r requirements.txt

COPY basic_flask.py basic_flask.py

# Bind to 0.0.0.0 so the dev server is reachable from outside the container
CMD ["flask", "--app", "basic_flask.py", "run", "--host", "0.0.0.0", "-p", "8080"]
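To try the image locally, something like `docker build -t flask-demo .` from inside `education/docker/`, followed by `docker run -p 8080:8080 flask-demo`, should serve the app at http://localhost:8080 (the `flask-demo` tag is just an illustrative name).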

education/docker/basic_flask.py

Lines changed: 11 additions & 0 deletions
from flask import Flask
app = Flask(__name__)


@app.route('/')
def hello_world():
    return "<h1>Hello World from Containerized App</h1>"


if __name__ == "__main__":
    app.run(debug=True)

education/docker/requirements.txt

Lines changed: 1 addition & 0 deletions
Flask==3.0.3

src/backend/.gitkeep

Whitespace-only changes.

src/backend/lambda.py

Lines changed: 32 additions & 0 deletions
import base64
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("trained/gpt2-728")
model.to(device)
model.eval()


# Generate a continuation of the input text
def predict_next_token(input_text, model=model, tokenizer=tokenizer, max_length=50):
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)

    # Pass max_length through so the parameter actually takes effect
    predicted_text = pipe(input_text, max_length=max_length)[0]["generated_text"]

    # The pipeline echoes the prompt, so strip it and return only the new text
    input_text_len = len(input_text)

    return predicted_text[input_text_len:]


def handler(event, context):
    try:
        if event.get('isBase64Encoded', False):
            body = base64.b64decode(event['body']).decode('utf-8')
        else:
            body = event['body']
    except (KeyError, json.JSONDecodeError) as e:
        return {"statusCode": 400, "body": f"Error processing request: {str(e)}"}
    return {"statusCode": 200, "body": predict_next_token(body)}

src/model/chunk.py

Lines changed: 54 additions & 0 deletions
## This file contains the function to chunk the input text into smaller pieces for the model to process
## The input text is tokenized and then split into chunks of size chunk_size

from transformers import AutoTokenizer
import torch

def chunk(inp: str,
          tokenizer: AutoTokenizer,
          chunk_size: int = 256,
          overlapping_len: int = 3,
          max_chunks: int = 128) -> dict:

    # Tokenize entire sample
    tokenized_txt = tokenizer(inp,
                              return_tensors="pt"
                              )["input_ids"].view(-1)
    token_len = len(tokenized_txt)

    # Add chunks
    chunks = []
    last_padding_size = 0

    # Each window advances by chunk_size - overlapping_len tokens, so
    # consecutive chunks share overlapping_len ids
    for i in range(0, token_len - overlapping_len, chunk_size - overlapping_len):
        # Exit if max_chunks is met
        if len(chunks) >= max_chunks:
            break

        # Create (potentially too short) new chunk
        new_chunk = tokenized_txt[i:i + chunk_size]

        # Generate (potentially empty) padding
        padding = torch.full(
            size=(chunk_size - len(new_chunk), ),
            fill_value=tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
        )
        last_padding_size = max(0, chunk_size - len(new_chunk))

        # Pad
        new_chunk = torch.cat((new_chunk, padding))

        # Add new correctly-sized chunk
        chunks.append(new_chunk)

    # Compile results
    input_ids = torch.stack(chunks)
    attention_mask = torch.ones_like(input_ids)

    # Mask out any padding at the end of the final chunk
    if last_padding_size > 0:
        attention_mask[-1, -last_padding_size:] = 0

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }
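To make the overlap arithmetic concrete: the loop advances by `chunk_size - overlapping_len` tokens per step, so consecutive chunks share exactly `overlapping_len` ids. A minimal sketch, assuming the GPT-2 tokenizer and an arbitrary input string:

```python
from transformers import AutoTokenizer

from chunk import chunk

tokenizer = AutoTokenizer.from_pretrained("gpt2")

out = chunk("import numpy as np\n" * 20, tokenizer,
            chunk_size=8, overlapping_len=3, max_chunks=4)

print(out["input_ids"].shape)       # torch.Size([4, 8]): 4 chunks of 8 ids each
print(out["attention_mask"].shape)  # same shape; zeros mark padding

# The window stride is 8 - 3 = 5, so the last 3 ids of one chunk
# reappear as the first 3 ids of the next
assert (out["input_ids"][0][-3:] == out["input_ids"][1][:3]).all()
```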

src/model/dataset.py

Lines changed: 136 additions & 0 deletions
from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer
from typing import Iterator

from preprocess import clean_comments, include, keep_only_content

from chunk import chunk

class CleanDataset(IterableDataset):
    TRAIN_SPLIT_NAME = "codeparrot/codeparrot-clean-train"
    VAL_SPLIT_NAME = "codeparrot/codeparrot-clean-valid"

    def __init__(self, train_split: bool, max_size: int = float("inf")):
        # Pick the train dataset when train_split is True, otherwise validation
        SPLIT_NAME = CleanDataset.TRAIN_SPLIT_NAME if train_split else CleanDataset.VAL_SPLIT_NAME

        # Set max size
        self.max_size = max_size

        # Load dataset
        ds = load_dataset(SPLIT_NAME,
                          streaming=True,
                          split="train") # Invariant for BOTH train and val sets

        # Preprocessing
        ds = ds.filter(lambda x: x["path"].endswith(".py"))              # Python only
        ds = ds.filter(lambda x: include(x["content"]))                  # DS imports only
        ds = ds.map(lambda x: {"content": clean_comments(x["content"])}) # Reformat code
        ds = ds.map(keep_only_content)                                   # Smaller samples

        # Prepare for torch DataLoader
        ds = ds.with_format("torch")

        # Enforce max size (take() needs an integer, so skip it when unbounded)
        if max_size != float("inf"):
            ds = ds.take(max_size)

        self.ds = ds

    def generate(self) -> Iterator[str]:
        i = 0 # Tracks attempt number for exception reporting

        for code_file in self.ds:
            i += 1

            # Yield when possible, skip and log when not
            try:
                yield code_file["content"]
            except StopIteration:
                break
            except Exception as e:
                print(f"[WARNING] Exception while loading sample {i}/{self.max_size}: {e}. Skipped item")
                continue

    def __iter__(self) -> Iterator:
        return self.generate()


class ChunkedDataset(CleanDataset):
    def __init__(self, train_split: bool, max_size: int, tokenizer: AutoTokenizer,
                 chunk_size: int = 256, chunk_overlap_len: int = 3, max_chunks: int = 128):

        super().__init__(train_split, max_size)

        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        self.overlapping_len = chunk_overlap_len
        self.max_chunks = max_chunks

    def generate(self) -> Iterator[dict]:
        count = 0

        for text in super().generate():
            # Attempt to chunk each code sample
            chunks = None
            try:
                chunks = chunk(inp=text,
                               tokenizer=self.tokenizer,
                               chunk_size=self.chunk_size,
                               overlapping_len=self.overlapping_len,
                               max_chunks=self.max_chunks)
            except Exception as e:
                print(f"[WARNING] Exception while chunking sample {count}/{self.max_size}: {e}. Skipped item")
                continue

            # Extract input ids and attention masks
            ids, mask = chunks["input_ids"], chunks["attention_mask"]

            # Yield each chunk, stopping if max_size is reached
            for i in range(ids.size()[0]):
                # Stop yielding if max_size is reached
                if count >= self.max_size:
                    break

                # Yield
                yield {
                    "input_ids": ids[i],
                    "attention_mask": mask[i],
                    "labels": ids[i].clone()
                }
                count += 1

            # Stop generating new chunks if max_size is reached
            if count >= self.max_size:
                break


# SAMPLE USAGE
if __name__ == "__main__":
    try:
        tokenizer = AutoTokenizer.from_pretrained("./tokenizer_10M")
    except OSError:
        print("[WARNING] tokenizer_10M folder was not found, defaulting to GPT2")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

    ds = ChunkedDataset(
        train_split=True,    # Use training split
        max_size=1_000_000,  # Provide up to 1 million samples (chunks, not files)
        tokenizer=tokenizer, # Set tokenizer
        chunk_size=256,      # Max length of id/mask sequences is 256
        chunk_overlap_len=3, # Chunks share 3 ids with the previous chunk
        max_chunks=128,      # Max chunks per file
    )

    # ChunkedDataset is iterable, so it can be passed directly to a DataLoader
    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset=ds,
        batch_size=16,
        # shuffle must stay unset: an IterableDataset has no known length to shuffle over
    )

    # Inspect a single batch
    for batch in loader:
        print(batch)
        break
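Since the commit title mentions model training, here is a rough sketch of how ChunkedDataset could feed a causal-LM training step. The GPT-2 base model, optimizer choice, and hyperparameters are illustrative assumptions, not the project's actual settings:

```python
# Rough training-loop sketch; model, optimizer, and hyperparameters
# below are illustrative assumptions, not the project's settings
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from dataset import ChunkedDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

loader = DataLoader(
    dataset=ChunkedDataset(train_split=True, max_size=1024, tokenizer=tokenizer),
    batch_size=8,
)

model.train()
for batch in loader:
    # The dataset already supplies "labels", so the forward pass
    # returns the causal-LM loss directly
    loss = model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```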

src/model/preprocess.py

Lines changed: 109 additions & 0 deletions
import regex as re

# ----------------------- Implementation ---------------------
def clean_doc(code: str, delim: str):
    string_pre = r"(\S+\s*=\s*)"
    between = rf"({delim})((?:(?!{delim})[\s\S])*)({delim})"

    # sentinel placeholder that should never occur in real code
    holder = "!|<1multiline1>|!"

    # first, protect assigned multiline strings with the placeholder
    code = re.sub(
        string_pre + between,
        rf"\1{holder}\3{holder}",
        code,
    )

    # remove docstrings without caring about multiline strings
    code = re.sub(between, "", code)

    # restore the placeholders to the original delimiter
    code = re.sub(re.escape(holder), delim, code)
    return code


# I'm not sure this was the best approach for doc comments
# Clean inline comments and block comments
def clean_comments(code: str) -> str:
    # remove comments starting with #
    code = re.sub(r"#[^\n]*", "", code)

    # remove docstrings delimited by """
    code = clean_doc(code, '"""')

    # remove docstrings delimited by '''
    code = clean_doc(code, "'''")

    # get rid of trailing whitespace and blank lines
    code = re.sub(r"\s*\n", "\n", code)
    return code


# used by the dataset filter lambda to decide which files to include as training data
def include(content: str) -> bool:
    libraries = [
        "numpy",
        "pandas",
        "matplotlib",
        "sklearn",
        "tensorflow",
        "torch",
        "scipy",
    ]

    return bool(re.search("|".join(libraries), content))


# Keeps only the code content of a sample
def keep_only_content(sample: dict) -> dict:
    return {"content": sample["content"]}


# preview cleaning and check
def preview():
    from dataset import CleanDataset

    ds = CleanDataset(
        train_split=False,
        max_size=100
    )

    for x in ds:
        print(x["content"])


def tests():
    input1 = """
string1 = '''
keep me
'''

'''
docstring1
'''

'''docstring2'''

'''docstring3
docstring3'''

string2 = '''keep me'''
"""

    input2 = '''
string = """
"hello there" 'hi there'
"""

"""
1234567890qwertyuiop[!@#$%^&*()].
"""
'''
    print(clean_comments(input1))
    print(clean_comments(input2))


if __name__ == "__main__":
    tests()
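As a concrete sanity check of what clean_comments keeps and drops (run from src/model/ so the import resolves; the snippet is illustrative):

```python
from preprocess import clean_comments

src = (
    "x = 1  # inline comment\n"
    '"""a bare docstring"""\n'
    "y = '''an assigned string'''\n"
)
print(clean_comments(src))
# The # comment and the bare docstring are stripped, while the
# assigned triple-quoted string survives, leaving roughly:
#   x = 1
#   y = '''an assigned string'''
```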

src/model/requirements.txt

Lines changed: 3 additions & 0 deletions
datasets == 3.1.0
torch == 2.5.0+cu121
transformers == 4.44.2
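Note: the `+cu121` build of torch lives on PyTorch's own package index, so installing this file typically requires pointing pip at it, e.g. with `--extra-index-url https://download.pytorch.org/whl/cu121`.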

0 commit comments
