from ragas.metrics.critique import AspectCritique
from llama_index import load_index_from_storage
from llama_index import StorageContext, set_global_service_context, ServiceContext
from datasets import Dataset
from langchain.embeddings import HuggingFaceEmbeddings
from llama_index.storage.docstore import SimpleDocumentStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.vector_stores import SimpleVectorStore
import json
from tqdm import tqdm
import os

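# Build a finetuning dataset of (question, positive chunk, negative chunks,
# hard-negative chunks) samples from RAGAS-scored QA data: for each low-scoring
# item, the retrieved chunk accepted by the critic becomes the positive, and the
# item's originally retrieved chunks are split into hard negatives and negatives.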
def write_to_json(filename, finetuning_dataset):
    # Append the new batch of samples to the JSON file on disk.
    with open(filename) as file:
        database = json.load(file)
    database.extend(finetuning_dataset)
    with open(filename, 'w') as file:
        json.dump(database, file, indent=4)

if __name__ == "__main__":

    embed_model = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    critic = AspectCritique(
        name="filter",
        definition="Does the submission contain information that can be derived from input?",
    )
    dataset = Dataset.from_json("wikidata/indices/subset.json")
    filename = "finetuning_dataset.json"

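    # Create the output file with an empty JSON list on the first run.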
    if not os.path.exists(filename):
        with open(filename, "w") as file:
            json.dump([], file)

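    # Process the dataset in batches; each batch loads its own persisted
    # llama_index storage from ./sample-<batch>.index/.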
    batch_size = 100
    max_ragas_score = 0.8
    threshold = 0.8
    for batch in tqdm(range(0, len(dataset) + 1, batch_size)):
        datapath = f"./sample-{batch}.index/"
        # create storage context using default stores
        storage_context = StorageContext.from_defaults(
            docstore=SimpleDocumentStore.from_persist_dir(persist_dir=datapath),
            vector_store=SimpleVectorStore.from_persist_dir(persist_dir=datapath),
            index_store=SimpleIndexStore.from_persist_dir(persist_dir=datapath),
        )
        set_global_service_context(service_context)
        index = load_index_from_storage(storage_context)
        retriever = index.as_retriever(similarity_top_k=1)
        subsample = dataset.select(range(batch, min(len(dataset), batch + batch_size)))
        finetuning_dataset = []

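        # Only items whose overall RAGAS score is at or below max_ragas_score are
        # turned into finetuning samples; higher-scoring items are skipped.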
        try:
            for item in subsample:
                if item["ragas_score"] <= max_ragas_score:

                    # Retrieve the single most similar chunk for the answer and run
                    # the critic on the (chunk, answer) pair; only samples the critic
                    # accepts contribute a positive chunk.
                    node = retriever.retrieve(item["Answer"])[0]
                    keep = critic.score_single(
                        {"question": node.get_content(), "answer": item["Answer"]}
                    )

                    # if node.get_score() >= threshold:
                    if keep:
                        pos_chunk = node.to_dict()
                    else:
                        continue

                    retrieved_chunks = item["chunks"]
                    # Hard negatives: every chunk ranked before the positive chunk
                    # (matched by node hash); chunks ranked after it are plain negatives.
                    hard = True
                    hard_negatives, negatives = [], []
                    for node in retrieved_chunks:

                        if node["node"]["hash"] == pos_chunk["node"]["hash"]:
                            hard = False
                            continue

                        if hard:
                            hard_negatives.append(node)
                        else:
                            negatives.append(node)

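                    # Assemble one finetuning sample: the original QA fields plus the
                    # positive, negative and hard-negative chunk texts.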
                    sample = {
                        "Question": item["Question"],
                        "Answer": item["Answer"],
                        "Context": item["Context"],
                        "Conversation_no": item["Conversation_no"],
                        "Turn_no": item["Turn_no"],
                        "Positives": [pos_chunk["node"]["text"]],
                        "Negatives": [chunk["node"]["text"] for chunk in negatives],
                        "Hard_negatives": [chunk["node"]["text"] for chunk in hard_negatives],
                    }
                    finetuning_dataset.append(sample)

            write_to_json(filename, finetuning_dataset)
        except Exception as e:
            print(e)