1
# Standard library
import pickle

# Third-party: BM25 keyword search and dense (bi-encoder) semantic search.
import nltk
import pandas as pd
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util

# Download the Punkt tokenizer model required by word_tokenize.
nltk.download('punkt')
9
+
10
class BM25:
    """BM25 keyword search over the ``merge`` text column of a dataset.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Expected to contain a ``merge`` text column and a ``video_id``
        column (assumed from usage below — TODO confirm schema).
    top_k : int
        Number of top-scoring documents to return per query.
    """

    def __init__(self, dataset, top_k=5):
        self.dataset = dataset
        self.top_k = top_k
        # Tokenize the corpus once up front; BM25 scores token lists.
        self.tokenized_corpus = [self.preprocess_text(doc) for doc in dataset['merge']]
        # Build the index once here instead of on every run() call:
        # it depends only on the corpus, never on the query.
        self.bm25 = BM25Okapi(self.tokenized_corpus)

    def preprocess_text(self, text):
        """Lower-case and word-tokenize ``text`` for BM25 scoring."""
        return word_tokenize(text.lower())

    def search(self, query, bm25=None):
        """Score ``query`` against the corpus and return the top-k video ids.

        ``bm25`` may still be supplied explicitly for backward
        compatibility; by default the index built in ``__init__`` is used.
        """
        if bm25 is None:
            bm25 = self.bm25
        tokenized_query = self.preprocess_text(query)
        scores = bm25.get_scores(tokenized_query)
        # Indices of the top-k scores, highest score first.
        top_n_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:self.top_k]

        results = [(self.dataset['merge'][i], scores[i]) for i in top_n_indices]
        video_ids = [self.dataset.loc[i]['video_id'] for i in top_n_indices]
        print(results)
        return video_ids

    def run(self, query):
        """Convenience wrapper: search ``query`` with the prebuilt index."""
        return self.search(query)
42
+
43
class BIENCODER:
    """Dense (bi-encoder) semantic search over precomputed corpus embeddings.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Expected to contain ``merge`` and ``video_id`` columns
        (assumed from usage below — TODO confirm schema).
    embeddings :
        Precomputed corpus embeddings compatible with
        ``util.semantic_search`` (loaded from a pickle by the caller).
    top_k : int
        Number of hits to retrieve per query.
    """

    def __init__(self, dataset, embeddings, top_k=5):
        self.dataset = dataset
        self.embeddings = embeddings
        self.top_k = top_k
        # Pretrained QA retrieval model; inputs truncated to 256 tokens.
        self.bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
        self.bi_encoder.max_seq_length = 256

    def search(self, query):
        """Embed ``query``, retrieve the top-k hits, and return their video ids."""
        print("Input question:", query)
        question_embedding = self.bi_encoder.encode(query, convert_to_tensor=True)
        # question_embedding = question_embedding.cuda()  # enable when running on GPU
        hits = util.semantic_search(question_embedding, self.embeddings, top_k=self.top_k)
        hits = hits[0]  # hits for the first (and only) query

        print("\n -------------------------\n ")
        # Fix: header previously hard-coded "Top-3" regardless of top_k.
        print("Top-{} Bi-Encoder Retrieval hits".format(self.top_k))
        hits = sorted(hits, key=lambda hit: hit['score'], reverse=True)
        video_ids = []
        for hit in hits:
            print("\t {:.3f}\t {}".format(hit['score'], self.dataset['merge'][hit['corpus_id']]))
            video_ids.append(self.dataset.loc[hit['corpus_id']]['video_id'])
        return video_ids
68
+
69
if __name__ == "__main__":
    query = 'CNCF Webinars'  # input query

    dataset = pd.read_csv('data/cncf_video_summary_combine.csv')

    # Method 1: sparse keyword retrieval with BM25.
    print('Method 1: BM25 alg for semantic search:')
    bm25_search = BM25(dataset, top_k=5)
    video_ids = bm25_search.run(query)
    print(video_ids)

    # Method 2: dense retrieval with a sentence-transformer bi-encoder
    # over precomputed embeddings.
    print('Method 2: Deep learning for semantic search:')
    # NOTE(review): pickle.load is unsafe on untrusted input — acceptable
    # here only because the embeddings file is produced locally.
    with open('data/embedding.pkl', 'rb') as f:
        embeddings = pickle.load(f)
    video_ids = BIENCODER(dataset, embeddings).search(query)
    print(video_ids)
0 commit comments