
Commit aae3c7e

Merge pull request #47 from arcee-ai/chore/training-experiment
Chore/training experiment
2 parents e3390f1 + 6a61ed4

File tree: 8 files changed (+143 lines, −8 lines)


dalm/datasets/qa_gen/question_answer_generation.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -118,7 +118,8 @@ def generate_qa_from_dataset(
     # shuffle data
     dataset.shuffle(seed=42)
     # select a subset
-    small_dataset = dataset.select(range(sample_size))
+    num_samples = min(sample_size, len(dataset))
+    small_dataset = dataset.select(range(num_samples))
     # train-test split
     small_dataset_splits = split_dataset(small_dataset, title_column_name)
     print(
```
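
A note on this fix: `datasets.Dataset.select` raises when given indices past the end of the dataset, so the old line failed whenever `--sample-size` exceeded the dataset length. A minimal sketch of the behavior on a toy dataset (not from this repo):

```python
# Minimal sketch of the clamp above, on a toy dataset (not from this repo).
from datasets import Dataset

dataset = Dataset.from_dict({"text": ["a", "b", "c"]})
sample_size = 1_000_000  # e.g. the README's --sample-size 1000000

# dataset.select(range(sample_size)) would raise here: indices exceed len(dataset)
num_samples = min(sample_size, len(dataset))
small_dataset = dataset.select(range(num_samples))
print(len(small_dataset))  # 3 -- the whole dataset
```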

dalm/training/rag_e2e/train_rage2e.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -513,7 +513,8 @@ def train_e2e(
     )
     generator_tokenizer.save_pretrained(generator_ckpt_path)
     accelerator.wait_for_everyone()
-    accelerator.end_training()
+    if with_tracking:
+        accelerator.end_training()
 
 
 def main() -> None:
```
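
For context, `Accelerator.end_training()` exists to finalize experiment trackers, so guarding it behind the tracking flag keeps the no-tracking path (as used with `--no-with-tracking` in the experiments README below) from touching tracker state. A rough sketch of the pattern, with the flag wiring assumed:

```python
# Rough sketch of the guarded teardown pattern (flag wiring is an assumption).
from accelerate import Accelerator

with_tracking = False  # e.g. set from a --with-tracking / --no-with-tracking CLI flag
accelerator = Accelerator(log_with="wandb" if with_tracking else None)

# ... training loop, checkpoint saving ...

accelerator.wait_for_everyone()
if with_tracking:
    accelerator.end_training()  # finalizes the initialized trackers only
```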

dalm/training/retriever_only/train_retriever_only.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -242,20 +242,20 @@ def train_retriever(
             passage_max_len=passage_max_len,
         ),
         batched=True,
-        remove_columns=dataset["train"].column_names,
+        remove_columns=dataset.column_names,
         desc="Running tokenizer on dataset",
     )
 
     # Log a few random samples from the training set:
-    for index in random.sample(range(len(processed_datasets["train"])), 3):
-        logger.info(f"Sample {index} of the training set: {processed_datasets['train'][index]}.")
+    for index in random.sample(range(len(processed_datasets)), 3):
+        logger.info(f"Sample {index} of the training set: {processed_datasets[index]}.")
 
     model.print_trainable_parameters()  # type: ignore # No idea what mypy is complaining about.
     accelerator.print(model)
 
     # get dataloaders
     train_dataloader = DataLoader(
-        processed_datasets["train"],
+        processed_datasets,
         shuffle=True,
         collate_fn=default_data_collator,
         batch_size=per_device_train_batch_size,
@@ -308,7 +308,7 @@ def train_retriever(
     accelerator.register_load_state_pre_hook(load_model_hook)
 
     logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len(processed_datasets['train'])}")
+    logger.info(f" Num examples = {len(processed_datasets)}")
     logger.info(f" Num Epochs = {num_train_epochs}")
     logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
@@ -409,7 +409,8 @@ def train_retriever(
     )
     tokenizer.save_pretrained(output_dir)
     accelerator.wait_for_everyone()
-    accelerator.end_training()
+    if with_tracking:
+        accelerator.end_training()
 
 
 def main() -> None:
```
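
These hunks all follow from one change: the retriever-only path now operates on a single `Dataset` rather than a `DatasetDict` keyed by `"train"`, so the `["train"]` lookups go away. A toy illustration of the difference (names are illustrative):

```python
# Toy illustration of DatasetDict vs. single-split Dataset access (illustrative names).
from datasets import Dataset, DatasetDict

ds = Dataset.from_dict({"text": ["a", "b"], "title": ["t1", "t2"]})

as_dict = DatasetDict({"train": ds})
print(as_dict["train"].column_names)  # the old code's ["train"] lookup

as_single = ds
print(as_single.column_names)  # ['text', 'title'] -- same API, no split key
```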

experiments/README.md

Lines changed: 91 additions & 0 deletions
# Testing e2e retriever hit rate

Inspired by [llama-index](https://gpt-index.readthedocs.io/en/latest/examples/finetuning/embeddings/finetune_embedding_adapter.html)

## Setup

```shell
pip install indomain
pip install -r requirements.txt
python data_gen.py
mkdir qa-outputs
```

## Create the datasets

First, create the train dataset

```shell
dalm qa-gen train_data.csv --output-dir qa-outputs --passage-column-name text --title-column-name title --sample-size 1000000
```

The oversized `--sample-size` effectively means "use every row"; it is safe because of the clamp added to `question_answer_generation.py` in this commit. This creates a train and a test file (because we typically want to split), so merge them into one

```shell
head -n 1 qa-outputs/question_answer_pairs_train.csv > question_answer_pairs.csv && tail -n+2 -q qa-outputs/*.csv >> question_answer_pairs.csv
rm qa-outputs/*.csv
mv question_answer_pairs.csv qa-outputs
```
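
The `head`/`tail` idiom keeps the header row from one CSV and appends the header-less data rows from every CSV in the directory. A rough pandas equivalent, for reference (paths assumed as above; the shell version is what the experiment actually ran):

```python
# Rough pandas equivalent of the head/tail merge above (paths are assumptions).
import glob

import pandas as pd

parts = [pd.read_csv(f) for f in sorted(glob.glob("qa-outputs/*.csv"))]
pd.concat(parts, ignore_index=True).to_csv("question_answer_pairs.csv", index=False)
```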
Same for the validation data

```shell
dalm qa-gen val_data.csv --output-dir qa-outputs-test --passage-column-name text --title-column-name title --sample-size 100000
head -n 1 qa-outputs-test/question_answer_pairs_train.csv > question_answer_pairs_test.csv && tail -n+2 -q qa-outputs-test/*.csv >> question_answer_pairs_test.csv
rm -rf qa-outputs-test
mv question_answer_pairs_test.csv qa-outputs
```

Now we have two files, one for training and one for eval

```shell
(.venv) root@f4ec1ae23983:# ls -lash qa-outputs/
total 2.3M
1.4M -rw-r--r-- 1 root root 1.4M Sep 20 20:02 question_answer_pairs.csv
956K -rw-r--r-- 1 root root 953K Sep 20 20:14 question_answer_pairs_test.csv
```

## RAG e2e training

Then we train end-to-end

```shell
dalm train-rag-e2e \
  "qa-outputs/question_answer_pairs.csv" \
  "BAAI/bge-small-en" \
  "meta-llama/Llama-2-7b-hf" \
  --dataset-passage-col-name text \
  --output-dir "rag_e2e_checkpoints_bgsmall" \
  --no-with-tracking \
  --per-device-train-batch-size 12
```

And eval

```shell
python ../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever --embed_dim 384

*************
Retriever results:
Recall: 0.8202054794520548
Hit Rate: 0.8202054794520548
*************
```
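
Recall and hit rate coming out identical is expected if each query has exactly one relevant passage (the passage its question was generated from): with a single relevant item, recall over the top-k and the hit rate are the same quantity. A tiny sketch of that equivalence, under the assumed single-relevant-passage setup:

```python
# With one relevant passage per query, recall@k and hit rate@k coincide (assumed setup).
def hit_rate_at_k(retrieved_ids, relevant_id, k=10):
    return int(relevant_id in retrieved_ids[:k])

# recall@k = (# relevant retrieved) / (# relevant) = hit / 1 = hit_rate@k
```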
## Retriever only training

Train the retriever only

```shell
dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \
  --output-dir "retriever_only_checkpoints_bgsmall" \
  --use-peft \
  --dataset-passage-col-name text \
  --per-device-train-batch-size 150
```

and eval

```shell
python ../dalm/eval/eval_retriever_only.py --dataset_path qa-outputs/question_answer_pairs_test.csv --retriever_model_name_or_path "BAAI/bge-small-en" --passage_column_name text --query_column_name Question --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ --embed_dim 384

*************
Retriever results:
Recall: 0.8116438356164384
Precision: 0.08116438356164453
Hit Rate: 0.8116438356164384
*************
```
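
A sanity check on these numbers: the reported precision is recall / 10 up to floating-point noise, which is what you would expect if the eval retrieves top-10 passages and each query has a single relevant one, since precision@k = hits / k per query. The k = 10 here is an inference, not something stated in the commit:

```python
# Sanity check: precision@k = recall / k when each query has one relevant passage.
recall = 0.8116438356164384
k = 10  # inferred, not stated in the commit
print(recall / k)  # ~0.0811643835616438 -- matches the reported precision above
```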

experiments/data_gen.py

Lines changed: 39 additions & 0 deletions
```python
import json

from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode

import pandas as pd

TRAIN_FILES = ["uber_2021.pdf"]
VAL_FILES = ["lyft_2021.pdf"]


def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    # Chunk the documents into 512-token nodes
    parser = SimpleNodeParser.from_defaults(chunk_size=512)
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes


train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

# One row per node: the node text is the passage, the node id doubles as a unique title
train_df = pd.DataFrame({"text": [node.text for node in train_nodes], "title": [node.id_ for node in train_nodes]})
val_df = pd.DataFrame({"text": [node.text for node in val_nodes], "title": [node.id_ for node in val_nodes]})

train_df.to_csv("train_data.csv")
val_df.to_csv("val_data.csv")
```
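
One practical note on the output: `to_csv` is called without `index=False`, so the CSVs carry pandas' index as an extra unnamed first column alongside `text` and `title`, the two columns the README's `--passage-column-name text --title-column-name title` flags refer to. A quick way to check (hypothetical snippet, run after `python data_gen.py`):

```python
# Hypothetical check of the generated schema (run after python data_gen.py).
import pandas as pd

print(pd.read_csv("train_data.csv").columns.tolist())
# ['Unnamed: 0', 'text', 'title'] -- the unnamed column is the pandas index
```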

experiments/lyft_2021.pdf

1.37 MB
Binary file not shown.

experiments/requirements.txt

Lines changed: 2 additions & 0 deletions
```
llama-index
pypdf
```

experiments/uber_2021.pdf

1.79 MB
Binary file not shown.
