Commit ce24f5e

WIP for axolotl trainer
1 parent e9da4b9 commit ce24f5e

16 files changed: +497 -1 lines changed

.editorconfig

Lines changed: 14 additions & 0 deletions
```ini
root = true

[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true

[*.py]
indent_style = space
indent_size = 4

[**.yml]
indent_style = space
indent_size = 2
```

.gitignore

Lines changed: 3 additions & 0 deletions
```
**/axolotl.egg-info
**/__pycache__
.idea
```

README.md

Lines changed: 8 additions & 1 deletion
````diff
 # Axolotl
 
-### You know you're going to axolotl questions
+#### You know you're going to axolotl questions
 
 
+### Converting JSON data files to JSONL
 
+```shell
+python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
+```
````
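For reference, the conversion simply rewrites a JSON array of records as one JSON object per line (JSON Lines). A tiny sketch with a made-up record, only to illustrate the output format:

```python
import json

# A made-up Alpaca-style record; not taken from the repo's data files.
records = [{"instruction": "Say hi", "input": "", "output": "Hi!"}]

# JSONL output: one json.dumps()'d object per line.
print("\n".join(json.dumps(r) for r in records))
```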

configs/pythia_1_2B_alpaca.yml

Lines changed: 37 additions & 0 deletions
```yaml
base_model: EleutherAI/pythia-1.4b-deduped
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
datasets:
  - path: ./data/alpaca_data_gpt4.jsonl
    type: alpaca
  - path: ./data/vicuna_cleaned.jsonl
    type: sharegpt
  - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
    type: gpteacher
  - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
    type: gpteacher
val_set_size: 0.05
adapter: lora
sequence_len: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - v_proj
wandb_project:
wandb_watch:
wandb_run_name:
wandb_log_model: checkpoint
output_dir: ./lora-alpaca
batch_size: 128
micro_batch_size: 8
num_epochs: 5
learning_rate: 0.0003
train_on_inputs: false
bf16: true
fp16: true
resume_from_checkpoint:
local_rank:
deepspeed:
```
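The batch settings in this config are not passed to the trainer directly; `scripts/finetune.py` below derives gradient accumulation from them. A quick sketch of that arithmetic, using the values above:

```python
# Mirrors the derived-hyperparameter arithmetic in scripts/finetune.py.
batch_size = 128          # target effective batch size (from the YAML above)
micro_batch_size = 8      # per-device batch size per step
world_size = 1            # finetune.py reads WORLD_SIZE from the environment; 1 = single process

gradient_accumulation_steps = batch_size // micro_batch_size   # 128 // 8 = 16
if world_size != 1:
    # under DDP the accumulation steps are split across processes
    gradient_accumulation_steps //= world_size

print(gradient_accumulation_steps)  # 16 on a single GPU
```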

data/README.md

Lines changed: 8 additions & 0 deletions
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o raw/roleplay-similarity_0.6-instruct-dataset.json
```
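As an optional sanity check (a sketch, not part of the commit) that the downloads above parse and to see how many records each holds:

```python
import json
from pathlib import Path

# File names match the curl commands above; run from the data/ directory.
for name in [
    "alpaca_data_gpt4.json",
    "vicuna_cleaned.json",
    "gpt4-instruct-similarity-0.6-dataset.json",
    "roleplay-similarity_0.6-instruct-dataset.json",
]:
    path = Path("raw") / name
    if path.exists():
        records = json.loads(path.read_text())
        print(f"{name}: {len(records)} records")
    else:
        print(f"{name}: missing")
```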

data/raw/.gitignore

Lines changed: 1 addition & 0 deletions
```
**
```

pyproject.toml

Lines changed: 3 additions & 0 deletions
```toml
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
```

requirements.txt

Lines changed: 6 additions & 0 deletions
```
git+https://github.com/huggingface/transformers.git
git+https://github.com/huggingface/peft.git
attrdict
fire
PyYAML==6.0
black
```

scripts/alpaca_json_to_jsonl.py

Lines changed: 36 additions & 0 deletions
```python
import os
import sys
from pathlib import Path

import fire
from typing import Optional

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.convert import *


def main(
    input: Path,
    output: Optional[Path] = None,
    to_stdout: Optional[bool] = False,
):
    file_reader = FileReader()
    if to_stdout or output is None:
        writer = StdoutWriter()
    else:
        writer = FileWriter(output)
    json_parser = JsonParser()
    jsonl_serializer = JsonlSerializer()

    converter = JsonToJsonlConverter(
        file_reader, writer, json_parser, jsonl_serializer
    )

    converter.convert(input, output)


if __name__ == "__main__":
    fire.Fire(main)
```

scripts/finetune.py

Lines changed: 129 additions & 0 deletions
```python
import os
import sys
from pathlib import Path

import fire
import torch
import transformers
import yaml
from attrdict import AttrDict
from datasets import load_dataset, IterableDataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_int8_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
src_dir = os.path.join(project_root, 'src')
sys.path.insert(0, src_dir)

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
    LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter


def setup_wandb_env_vars(cfg):
    # only enable wandb features that are actually set in the config
    if cfg.wandb_project and len(cfg.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = cfg.wandb_project
        cfg.use_wandb = True
    if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = cfg.wandb_watch
    if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model


def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
    if adapter != "lora":
        raise NotImplementedError(f"{adapter} peft adapter not available")
    try:
        model = getattr(transformers, model_type).from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )
    except Exception:
        # fall back to the Auto class if model_type isn't a transformers class name
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit,
            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
            device_map=cfg.device_map,
        )

    try:
        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(base_model)

    if tokenizer.__class__.__name__ == "LlamaTokenizer":
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

    if cfg.load_in_8bit:
        model = prepare_model_for_int8_training(model)

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    if cfg.ddp:
        model.to(f"cuda:{cfg.local_rank}")

    # TODO resume_from_checkpoint handling

    model.print_trainable_parameters()
    return model, tokenizer


def train(
    config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
    **kwargs,
):
    # load the config from the yaml file
    with open(config, 'r') as f:
        cfg: AttrDict = AttrDict(yaml.safe_load(f))
    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
    # then overwrite the value
    for k, v in kwargs.items():
        if k in cfg:
            cfg[k] = v

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.device_map = "auto"
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    cfg.ddp = cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
    setup_wandb_env_vars(cfg)

    # Load the model and tokenizer
    model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
    datasets = []
    for d in cfg.datasets:
        ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
        if d.type == "alpaca":
            ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif d.type == "gpteacher":
            ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)
        elif d.type == "sharegpt":
            ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
            ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
            datasets.append(ds_wrapper)


if __name__ == "__main__":
    fire.Fire(train)
```
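The training loop itself is not part of this WIP commit; the script currently stops after building the tokenized dataset wrappers. As a rough, hypothetical sketch of how those pieces might later be handed to a Hugging Face `Trainer` (assuming each `TokenizedPromptDataset` behaves as a torch iterable dataset yielding `input_ids`/`attention_mask`/`labels` dicts; `continue_training` and its arguments are illustrative, not from the commit):

```python
# Hypothetical continuation, not part of this commit.
import torch
import transformers


def continue_training(cfg, model, tokenizer, datasets):
    # Chain the per-dataset iterables into a single iterable training set.
    train_dataset = torch.utils.data.ChainDataset(datasets)

    training_args = transformers.TrainingArguments(
        output_dir=cfg.output_dir,
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        fp16=True,
        logging_steps=10,
        max_steps=1000,  # iterable datasets have no length, so a step budget is needed
    )

    trainer = transformers.Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, return_tensors="pt", padding=True
        ),
    )
    trainer.train()
```

With streaming/iterable datasets, `max_steps` rather than an epoch count bounds training, which is why it appears explicitly above.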

setup.cfg

Lines changed: 23 additions & 0 deletions
```ini
[metadata]
name = axolotl
version = 0.1.0
description = You know you're going to axolotl questions
author = Wing Lian
author_email = [email protected]
license = MIT

[options]
package_dir =
    =src
packages = find:
install_requires =
    transformers @ git+https://github.com/huggingface/transformers.git@main
    peft @ git+https://github.com/huggingface/peft.git@main
    attrdict
    fire
    PyYAML == 6.0
    black

[options.packages.find]
where = src
```

src/axolotl/__init__.py

Whitespace-only changes.

src/axolotl/convert.py

Lines changed: 50 additions & 0 deletions
```python
import json
import sys


class FileReader:
    def read(self, file_path):
        with open(file_path, "r") as file:
            return file.read()


class FileWriter:
    def __init__(self, file_path):
        self.file_path = file_path

    def write(self, content):
        with open(self.file_path, "w") as file:
            file.write(content)


class StdoutWriter:
    def write(self, content):
        sys.stdout.write(content)
        sys.stdout.write("\n")


class JsonParser:
    def parse(self, content):
        return json.loads(content)


class JsonlSerializer:
    def serialize(self, data):
        lines = [json.dumps(item) for item in data]
        return "\n".join(lines)


class JsonToJsonlConverter:
    def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
        self.file_reader = file_reader
        self.file_writer = file_writer
        self.json_parser = json_parser
        self.jsonl_serializer = jsonl_serializer

    def convert(self, input_file_path, output_file_path):
        content = self.file_reader.read(input_file_path)
        data = self.json_parser.parse(content)
        jsonl_content = self.jsonl_serializer.serialize(data)
        self.file_writer.write(jsonl_content)
```
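A minimal usage sketch of these classes, mirroring what `scripts/alpaca_json_to_jsonl.py` wires together (the paths are placeholders):

```python
# Minimal sketch; "example.json" / "example.jsonl" are placeholder paths.
from axolotl.convert import (
    FileReader, FileWriter, JsonParser, JsonlSerializer, JsonToJsonlConverter,
)

converter = JsonToJsonlConverter(
    FileReader(), FileWriter("example.jsonl"), JsonParser(), JsonlSerializer()
)
# Reads the JSON array from example.json and writes one JSON object per line.
converter.convert("example.json", "example.jsonl")
```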
