Commit 6526d69

Merge pull request #8 from shahules786/finetune-emb
Embedding finetuning experiments
2 parents 540741b + 62385b3

File tree

7 files changed: +6532 -1 lines changed

Diff for: .gitignore

+1 -1

@@ -3,7 +3,7 @@
 # python
 **/.ipynb_checkpoints
 .python-version
-
+finetune-embedding/wikidata
 # integrations
 **/.chroma/index
 **/wandb

Diff for: finetune-embedding/data_creation.ipynb

+765
Large diffs are not rendered by default.

Diff for: finetune-embedding/finetune_dataset.py

+96

@@ -0,0 +1,96 @@

from ragas.metrics.critique import AspectCritique
from llama_index import load_index_from_storage
from llama_index import StorageContext, set_global_service_context, ServiceContext
from datasets import Dataset
from langchain.embeddings import HuggingFaceEmbeddings
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.vector_stores import SimpleVectorStore
import json
from tqdm import tqdm
import os


def write_to_json(filename, finetuning_dataset):
    # Append the newly mined samples to the existing JSON file.
    database = json.load(open(filename))
    database.extend(finetuning_dataset)
    with open(filename, "w") as file:
        json.dump(database, file, indent=4)


if __name__ == "__main__":

    embed_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    critic = AspectCritique(
        name="filter",
        definition="Does the submission contain information that can be derived from input?",
    )
    dataset = Dataset.from_json("wikidata/indices/subset.json")
    filename = "finetuning_dataset.json"

    if not os.path.exists(filename):
        with open(filename, "w") as file:
            json.dump([], file)

    batch_size = 100
    max_ragas_score = 0.8
    threshold = 0.8

    for batch in tqdm(range(0, len(dataset) + 1, batch_size)):
        datapath = f"./sample-{batch}.index/"
        # create storage context using default stores
        storage_context = StorageContext.from_defaults(
            docstore=SimpleDocumentStore.from_persist_dir(persist_dir=datapath),
            vector_store=SimpleVectorStore.from_persist_dir(persist_dir=datapath),
            index_store=SimpleIndexStore.from_persist_dir(persist_dir=datapath),
        )
        set_global_service_context(service_context)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=1)
        subsample = dataset.select(range(batch, min(len(dataset), batch + batch_size)))
        finetuning_dataset = []

        try:
            for item in subsample:
                if item["ragas_score"] <= max_ragas_score:

                    # Retrieve the closest chunk for the answer and let the
                    # critic decide whether it actually supports the answer.
                    node = retriever.retrieve(item["Answer"])[0]
                    verdict = critic.score_single(
                        {"question": node.get_content(), "answer": item["Answer"]}
                    )

                    # if node.get_score() >= threshold:
                    if verdict:
                        pos_chunk = node.to_dict()
                    else:
                        continue

                    retrieved_chunks = item["chunks"]
                    # chunks ranked before the positive chunk are hard negatives,
                    # the remaining ones are plain negatives
                    hard = True
                    hard_negatives, negatives = [], []
                    for node in retrieved_chunks:

                        if node["node"]["hash"] == pos_chunk["node"]["hash"]:
                            hard = False
                            continue

                        if hard:
                            hard_negatives.append(node)
                        else:
                            negatives.append(node)

                    sample = {
                        "Question": item["Question"],
                        "Answer": item["Answer"],
                        "Context": item["Context"],
                        "Conversation_no": item["Conversation_no"],
                        "Turn_no": item["Turn_no"],
                        "Positives": [pos_chunk["node"]["text"]],
                        "Negatives": [chunk["node"]["text"] for chunk in negatives],
                        "Hard_negatives": [chunk["node"]["text"] for chunk in hard_negatives],
                    }
                    finetuning_dataset.append(sample)

            write_to_json(filename, finetuning_dataset)
        except Exception as e:
            print(e)
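The script writes (Question, Positives, Negatives, Hard_negatives) records to finetuning_dataset.json. A minimal sketch of how such records could feed the actual embedding finetuning, assuming the sentence-transformers library with MultipleNegativesRankingLoss; the loss choice, batch size, and output path are illustrative and not part of this commit:

import json

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Load the mined samples produced by finetune_dataset.py
with open("finetuning_dataset.json") as f:
    records = json.load(f)

# One training example per (question, positive, hard negative) triple;
# records without hard negatives are skipped in this sketch.
examples = [
    InputExample(texts=[rec["Question"], rec["Positives"][0], neg])
    for rec in records
    for neg in rec["Hard_negatives"]
]

model = SentenceTransformer("BAAI/bge-small-en-v1.5")
loader = DataLoader(examples, shuffle=True, batch_size=32)
# Uses other in-batch passages as negatives and the third text as an explicit hard negative.
loss = losses.MultipleNegativesRankingLoss(model)

model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=100)
model.save("bge-small-en-v1.5-finetuned")

The mined hard negatives are what make this setup worthwhile: with in-batch negatives alone the contrast is often too easy, whereas chunks ranked above the verified positive force the model to separate near-duplicates.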
