
Merge pull request #47 from arcee-ai/chore/training-experiment
Chore/training experiment
Ben-Epstein authored Sep 20, 2023
2 parents e3390f1 + 6a61ed4 commit aae3c7e
Showing 8 changed files with 143 additions and 8 deletions.
3 changes: 2 additions & 1 deletion dalm/datasets/qa_gen/question_answer_generation.py
@@ -118,7 +118,8 @@ def generate_qa_from_dataset(
     # shuffle data
     dataset.shuffle(seed=42)
     # select a subset
-    small_dataset = dataset.select(range(sample_size))
+    num_samples = min(sample_size, len(dataset))
+    small_dataset = dataset.select(range(num_samples))
     # train-test split
     small_dataset_splits = split_dataset(small_dataset, title_column_name)
     print(
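The new `min()` clamp guards the case where the requested sample size exceeds the dataset. A minimal standalone sketch of the pattern, assuming a Hugging Face `datasets.Dataset` (toy data and names are illustrative, not from the repo):

```python
from datasets import Dataset

dataset = Dataset.from_dict({"text": ["a", "b", "c"]})
sample_size = 1_000_000  # a --sample-size larger than the dataset

# Without the clamp, select(range(sample_size)) fails once the range runs
# past the last valid index; min() caps it at the dataset length.
num_samples = min(sample_size, len(dataset))
small_dataset = dataset.select(range(num_samples))
print(len(small_dataset))  # 3
```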
3 changes: 2 additions & 1 deletion dalm/training/rag_e2e/train_rage2e.py
@@ -513,7 +513,8 @@ def train_e2e(
         )
         generator_tokenizer.save_pretrained(generator_ckpt_path)
     accelerator.wait_for_everyone()
-    accelerator.end_training()
+    if with_tracking:
+        accelerator.end_training()
 
 
 def main() -> None:
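The guard keeps `end_training()` paired with tracker initialization. A hedged sketch of that pairing using the public `accelerate` API (the tracker name, flag wiring, and elided training loop are illustrative, not the repo's exact code):

```python
from accelerate import Accelerator

with_tracking = False  # e.g. set by the CLI's --no-with-tracking flag

accelerator = Accelerator(log_with="wandb" if with_tracking else None)
if with_tracking:
    accelerator.init_trackers("rag_e2e")

# ... training loop ...

accelerator.wait_for_everyone()
# end_training() finalizes any initialized trackers, so it is only
# meaningful (and safe) when tracking was actually set up.
if with_tracking:
    accelerator.end_training()
```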
13 changes: 7 additions & 6 deletions dalm/training/retriever_only/train_retriever_only.py
@@ -242,20 +242,20 @@ def train_retriever(
             passage_max_len=passage_max_len,
         ),
         batched=True,
-        remove_columns=dataset["train"].column_names,
+        remove_columns=dataset.column_names,
         desc="Running tokenizer on dataset",
     )
 
     # Log a few random samples from the training set:
-    for index in random.sample(range(len(processed_datasets["train"])), 3):
-        logger.info(f"Sample {index} of the training set: {processed_datasets['train'][index]}.")
+    for index in random.sample(range(len(processed_datasets)), 3):
+        logger.info(f"Sample {index} of the training set: {processed_datasets[index]}.")
 
     model.print_trainable_parameters()  # type: ignore # No idea what mypy is complaining about.
     accelerator.print(model)
 
     # get dataloaders
     train_dataloader = DataLoader(
-        processed_datasets["train"],
+        processed_datasets,
         shuffle=True,
         collate_fn=default_data_collator,
         batch_size=per_device_train_batch_size,
@@ -308,7 +308,7 @@ def train_retriever(
     accelerator.register_load_state_pre_hook(load_model_hook)
 
     logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len(processed_datasets['train'])}")
+    logger.info(f" Num examples = {len(processed_datasets)}")
     logger.info(f" Num Epochs = {num_train_epochs}")
     logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
@@ -409,7 +409,8 @@ def train_retriever(
         )
         tokenizer.save_pretrained(output_dir)
     accelerator.wait_for_everyone()
-    accelerator.end_training()
+    if with_tracking:
+        accelerator.end_training()
 
 
 def main() -> None:
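The `["train"]` removals reflect that this path now receives a plain `datasets.Dataset` rather than a split `DatasetDict`. A small sketch of the distinction (toy data, illustrative column names):

```python
from datasets import Dataset, DatasetDict

ds = Dataset.from_dict({"text": ["a", "b"], "question": ["q1", "q2"]})

# A DatasetDict is keyed by split name, so it must be indexed first:
splits = DatasetDict({"train": ds})
print(splits["train"].column_names)  # ['text', 'question']

# A bare Dataset exposes column_names, len(), and row indexing directly:
print(ds.column_names)  # ['text', 'question']
print(len(ds), ds[0])   # 2 {'text': 'a', 'question': 'q1'}
```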
91 changes: 91 additions & 0 deletions experiments/README.md
@@ -0,0 +1,91 @@
# Testing e2e retriever hit rate


Inspired by [llama-index](https://gpt-index.readthedocs.io/en/latest/examples/finetuning/embeddings/finetune_embedding_adapter.html)

## Setup

```shell
pip install indomain
pip install -r requirements.txt
python data_gen.py
mkdir qa-outputs
```

## Create the datasets
First, create the train dataset:
```shell
dalm qa-gen train_data.csv --output-dir qa-outputs --passage-column-name text --title-column-name title --sample-size 1000000
```
This creates a train file and a test file (because we typically want a split). Merge them into one by keeping the header row from the first file and appending the header-less data rows from both:
```shell
head -n 1 qa-outputs/question_answer_pairs_train.csv > question_answer_pairs.csv && tail -n+2 -q qa-outputs/*.csv >> question_answer_pairs.csv
rm qa-outputs/*.csv
mv question_answer_pairs.csv qa-outputs
```
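If you'd rather merge in Python than shell, a rough pandas equivalent (assuming the two split files produced by `dalm qa-gen` above):

```python
import pandas as pd

parts = [
    "qa-outputs/question_answer_pairs_train.csv",
    "qa-outputs/question_answer_pairs_test.csv",
]
# Stack the rows from both files under a single header and write one CSV.
merged = pd.concat([pd.read_csv(p) for p in parts], ignore_index=True)
merged.to_csv("qa-outputs/question_answer_pairs.csv", index=False)
```

(Delete the per-split files afterwards if you want to mirror the shell steps exactly.)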

Do the same for the validation data:
```shell
dalm qa-gen val_data.csv --output-dir qa-outputs-test --passage-column-name text --title-column-name title --sample-size 100000
head -n 1 qa-outputs-test/question_answer_pairs_train.csv > question_answer_pairs_test.csv && tail -n+2 -q qa-outputs-test/*.csv >> question_answer_pairs_test.csv
rm -rf qa-outputs-test
mv question_answer_pairs_test.csv qa-outputs
```

Now we have two files in `qa-outputs/`, one for training and one for eval:
```shell
(.venv) root@f4ec1ae23983:# ls -lash qa-outputs/
total 2.3M
1.4M -rw-r--r-- 1 root root 1.4M Sep 20 20:02 question_answer_pairs.csv
956K -rw-r--r-- 1 root root 953K Sep 20 20:14 question_answer_pairs_test.csv
```

## RAG e2e training

Then we train end-to-end:
```shell
dalm train-rag-e2e \
"qa-outputs/question_answer_pairs.csv" \
"BAAI/bge-small-en" \
"meta-llama/Llama-2-7b-hf" \
--dataset-passage-col-name text \
--output-dir "rag_e2e_checkpoints_bgsmall" \
--no-with-tracking \
--per-device-train-batch-size 12
```

And evaluate:
```
python ../dalm/eval/eval_retriever_only.py \
  --dataset_path qa-outputs/question_answer_pairs_test.csv \
  --retriever_model_name_or_path "BAAI/bge-small-en" \
  --passage_column_name text \
  --query_column_name Question \
  --retriever_peft_model_path rag_e2e_checkpoints_bgsmall/retriever \
  --embed_dim 384
*************
Retriever results:
Recall: 0.8202054794520548
Hit Rate: 0.8202054794520548
*************
```
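Recall and hit rate coincide here because each generated question has exactly one gold passage. A toy sketch of the metric under that assumption (a hypothetical helper, not the repo's eval code):

```python
def hit_rate_at_k(retrieved: list[list[str]], gold: list[str]) -> float:
    """Fraction of queries whose single gold passage appears in the top-k."""
    hits = sum(g in topk for topk, g in zip(retrieved, gold))
    return hits / len(gold)

retrieved = [["p1", "p7"], ["p3", "p2"], ["p9", "p4"]]  # top-2 per query
gold = ["p1", "p2", "p5"]
print(hit_rate_at_k(retrieved, gold))  # 2/3 ≈ 0.667; equals recall here
```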

## Retriever-only training

Train just the retriever:
```shell
dalm train-retriever-only "BAAI/bge-small-en" "qa-outputs/question_answer_pairs.csv" \
--output-dir "retriever_only_checkpoints_bgsmall" \
--use-peft \
--dataset-passage-col-name text \
--per-device-train-batch-size 150
```

And evaluate:
```
python ../dalm/eval/eval_retriever_only.py \
  --dataset_path qa-outputs/question_answer_pairs_test.csv \
  --retriever_model_name_or_path "BAAI/bge-small-en" \
  --passage_column_name text \
  --query_column_name Question \
  --retriever_peft_model_path retriever_only_checkpoints_bgsmall/ \
  --embed_dim 384
*************
Retriever results:
Recall: 0.8116438356164384
Precision: 0.08116438356164453
Hit Rate: 0.8116438356164384
*************
```
39 changes: 39 additions & 0 deletions experiments/data_gen.py
@@ -0,0 +1,39 @@
from llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
import pandas as pd

TRAIN_FILES = ["uber_2021.pdf"]
VAL_FILES = ["lyft_2021.pdf"]


def load_corpus(files, verbose=False):
    if verbose:
        print(f"Loading files {files}")

    reader = SimpleDirectoryReader(input_files=files)
    docs = reader.load_data()
    if verbose:
        print(f"Loaded {len(docs)} docs")

    # Chunk each document into 512-token nodes (the retrieval passages).
    parser = SimpleNodeParser.from_defaults(chunk_size=512)
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes


train_nodes = load_corpus(TRAIN_FILES, verbose=True)
val_nodes = load_corpus(VAL_FILES, verbose=True)

# One row per node: the chunk text, with the node id standing in as a title.
train_df = pd.DataFrame({"text": [node.text for node in train_nodes], "title": [node.id_ for node in train_nodes]})
val_df = pd.DataFrame({"text": [node.text for node in val_nodes], "title": [node.id_ for node in val_nodes]})

train_df.to_csv("train_data.csv")
val_df.to_csv("val_data.csv")
Binary file added experiments/lyft_2021.pdf
2 changes: 2 additions & 0 deletions experiments/requirements.txt
@@ -0,0 +1,2 @@
llama-index
pypdf
Binary file added experiments/uber_2021.pdf
