Commit 64c98ac

filter dataset
1 parent efa8f68 commit 64c98ac

File tree

1 file changed: +92 -0 lines changed
from ragas.metrics.critique import AspectCritique
from llama_index import (
    ServiceContext,
    StorageContext,
    load_index_from_storage,
    set_global_service_context,
)
from datasets import Dataset
from langchain.embeddings import HuggingFaceEmbeddings
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.vector_stores import SimpleVectorStore
import json
from tqdm import tqdm
import os

def write_to_json(filename, finetuning_dataset):
    # Append the new samples to the JSON array already on disk.
    with open(filename) as file:
        database = json.load(file)
    database.extend(finetuning_dataset)
    with open(filename, "w") as file:
        json.dump(database, file, indent=4)
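
A minimal usage sketch of the helper; the demo file name and sample fields here are illustrative:

# Illustrative usage of write_to_json: seed an empty array, then append twice.
with open("demo_dataset.json", "w") as file:
    json.dump([], file)

write_to_json("demo_dataset.json", [{"Question": "Q1", "Answer": "A1"}])
write_to_json("demo_dataset.json", [{"Question": "Q2", "Answer": "A2"}])
# demo_dataset.json now contains a single array holding both samples.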

if __name__ == "__main__":

    embed_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    critic = AspectCritique(
        name="filter",
        definition="Does the submission contain information that can be derived from input?",
    )
    dataset = Dataset.from_json("wikidata/indices/subset.json")
    filename = "finetuning_dataset.json"

    # Seed the output file with an empty JSON array on the first run.
    if not os.path.exists(filename):
        with open(filename, "w") as file:
            json.dump([], file)

    batch_size = 100
    max_ragas_score = 0.8
    threshold = 0.8  # retriever-score cutoff; only used by the commented-out check below
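
For orientation, each record of subset.json is expected to carry the fields the loop below reads; a hypothetical record, with illustrative values:

# Hypothetical shape of one subset.json record, inferred from the field
# accesses in the loop below; all values are made up.
example_item = {
    "Question": "Who wrote The Selfish Gene?",
    "Answer": "Richard Dawkins wrote The Selfish Gene.",
    "ragas_score": 0.65,
    "chunks": [
        {"node": {"hash": "h1", "text": "The Selfish Gene is a 1976 book ..."}},
        {"node": {"hash": "h2", "text": "Richard Dawkins is a British biologist ..."}},
    ],
}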

    # Process the dataset in batches; each batch has its own persisted index.
    for batch in tqdm(range(0, len(dataset), batch_size)):

        datapath = f"./sample-{batch}.index/"
        # Create a storage context from the stores persisted for this batch.
        storage_context = StorageContext.from_defaults(
            docstore=SimpleDocumentStore.from_persist_dir(persist_dir=datapath),
            vector_store=SimpleVectorStore.from_persist_dir(persist_dir=datapath),
            index_store=SimpleIndexStore.from_persist_dir(persist_dir=datapath),
        )
        set_global_service_context(service_context)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=1)
        subsample = dataset.select(range(batch, min(len(dataset), batch + batch_size)))
        finetuning_dataset = []
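
The loop assumes every batch was indexed and persisted beforehand under sample-{batch}.index/. A sketch of how such a directory could be produced with the same legacy llama_index API; the document content is illustrative:

# Sketch: persist a per-batch index that load_index_from_storage can reload.
# Reuses service_context from the setup above; the document text is made up.
from llama_index import Document, VectorStoreIndex

docs = [Document(text="Richard Dawkins is a British evolutionary biologist ...")]
batch_index = VectorStoreIndex.from_documents(docs, service_context=service_context)
batch_index.storage_context.persist(persist_dir="./sample-0.index/")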

        try:
            for item in subsample:
                if item["ragas_score"] <= max_ragas_score:

                    # Retrieve the best-matching chunk for the answer, then ask
                    # the critic whether the answer can be derived from it.
                    node = retriever.retrieve(item["Answer"])[0]
                    verdict = critic.score_single(
                        {"question": node.get_content(), "answer": item["Answer"]}
                    )

                    # if node.get_score() >= threshold:
                    if verdict:
                        pos_chunk = node.to_dict()
                    else:
                        continue

                    # Chunks ranked before the positive chunk are hard negatives;
                    # chunks ranked after it are plain negatives.
                    retrieved_chunks = item["chunks"]
                    hard = True
                    hard_negatives, negatives = [], []
                    for chunk in retrieved_chunks:
                        if chunk["node"]["hash"] == pos_chunk["node"]["hash"]:
                            hard = False
                            continue  # skip the positive chunk itself
                        if hard:
                            hard_negatives.append(chunk)
                        else:
                            negatives.append(chunk)

                    sample = {
                        "Question": item["Question"],
                        "Answer": item["Answer"],
                        "positives": [pos_chunk["node"]["text"]],
                        "negatives": [chunk["node"]["text"] for chunk in negatives],
                        "hard_negatives": [chunk["node"]["text"] for chunk in hard_negatives],
                    }
                    finetuning_dataset.append(sample)

            write_to_json(filename, finetuning_dataset)
        except Exception as e:
            print(e)
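
To see the hard/plain negative split in isolation, a self-contained toy run of the same hash-based partition (data made up):

# Toy run of the hash-based split used above; all data is made up.
pos_hash = "h2"
retrieved = [
    {"node": {"hash": "h1", "text": "ranked above the positive"}},
    {"node": {"hash": "h2", "text": "the positive chunk itself"}},
    {"node": {"hash": "h3", "text": "ranked below the positive"}},
]

hard, hard_negatives, negatives = True, [], []
for chunk in retrieved:
    if chunk["node"]["hash"] == pos_hash:
        hard = False
        continue
    (hard_negatives if hard else negatives).append(chunk)

print([c["node"]["text"] for c in hard_negatives])  # ['ranked above the positive']
print([c["node"]["text"] for c in negatives])       # ['ranked below the positive']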
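
After a successful run, finetuning_dataset.json holds one record per kept question, roughly of this shape (all values illustrative):

# Rough shape of finetuning_dataset.json after a run; values are made up.
expected_output = [
    {
        "Question": "Who wrote The Selfish Gene?",
        "Answer": "Richard Dawkins wrote The Selfish Gene.",
        "positives": ["Richard Dawkins is a British biologist ..."],
        "negatives": ["The Selfish Gene is a 1976 book ..."],
        "hard_negatives": ["ranked above the positive ..."],
    }
]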
