# beam_retriever.py
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import pickle


class BEAM_HEAD:
    """Scoring head: embeds text with a pretrained encoder and ranks
    passages by cosine similarity to the question."""

    def __init__(self, model_name="law-ai/InLegalBERT"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def get_embeddings(self, text):
        # Mean-pool the last hidden state into a single fixed-size embedding.
        encoded_input = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            output = self.model(**encoded_input)
        last_hidden_state = output.last_hidden_state
        embeddings = last_hidden_state.mean(dim=1)
        return embeddings

    def compute_similarity(self, text1, text2):
        embeddings1 = self.get_embeddings(text1)
        embeddings2 = self.get_embeddings(text2)
        similarity = F.cosine_similarity(embeddings1, embeddings2)
        return similarity.item()

    def score(self, text1, text2):
        return self.compute_similarity(text1, text2)

    def save(self, filepath):
        # Pickle the tokenizer and model together so the head can be
        # restored later without re-downloading the pretrained weights.
        with open(filepath, 'wb') as f:
            pickle.dump({'model_name': self.model_name,
                         'tokenizer': self.tokenizer,
                         'model': self.model}, f)

    @classmethod
    def load(cls, filepath):
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        # Bypass __init__ so loading does not trigger a fresh model download.
        instance = cls.__new__(cls)
        instance.model_name = data.get('model_name')
        instance.tokenizer = data['tokenizer']
        instance.model = data['model']
        return instance
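
# Example (a minimal sketch, not in the original file; the strings are
# hypothetical): scoring one question/passage pair with the head.
#
#   head = BEAM_HEAD()  # downloads law-ai/InLegalBERT on first use
#   head.score("When can anticipatory bail be granted?",
#              "Section 438 CrPC empowers courts to grant anticipatory bail.")
#   # -> a float cosine similarity in [-1, 1]
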
class BeamRetriever:
    """Beam-search retriever: keeps the B highest-scoring passage chains
    and extends them for K steps, re-scoring each candidate chain against
    the question with the scoring head."""

    def __init__(self, head, B=1, K=2):
        self.head = head
        self.B = B  # beam width: number of chains kept after each step
        self.K = K  # number of expansion steps after the initial selection

    def retrieve(self, question, passages):
        """Return the set of passage ids used by the selected chains.

        `passages` maps each passage id to its text."""
        selected = set()

        # Step 0: score every passage against the question on its own and
        # seed the beam with the top-B single-passage chains.
        scores = sorted(
            ((self.head.compute_similarity(question, passages[pid]), pid)
             for pid in passages),
            key=lambda x: x[0], reverse=True)[:self.B]
        chains = [([pid], score) for score, pid in scores]
        seed_ids = {pid for _, pid in scores}
        selected |= seed_ids

        # Steps 1..K: extend every chain in the beam with every passage not
        # already used, score the concatenated chain text against the
        # question, accumulate that onto the chain's running score, and
        # keep the top-B extended chains.
        for _ in range(self.K):
            candidates = []
            for pid in passages:
                if pid in seed_ids:
                    continue
                for nodes, score in chains:
                    if pid in nodes:
                        continue
                    chain_text = ' '.join(passages[n] for n in nodes + [pid])
                    new_score = score + self.head.compute_similarity(question, chain_text)
                    candidates.append((nodes + [pid], new_score))
            chains = sorted(candidates, key=lambda x: x[1], reverse=True)[:self.B]

        for nodes, _ in chains:
            selected.update(nodes)
        return selected
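
# Usage sketch (not part of the original module): build a toy passage pool
# and run the retriever end to end. The passage texts and question below
# are made-up placeholders.
if __name__ == "__main__":
    head = BEAM_HEAD()  # downloads the pretrained encoder on first use
    retriever = BeamRetriever(head, B=2, K=1)
    passages = {
        0: "Section 438 CrPC empowers courts to grant anticipatory bail.",
        1: "The contract was held void for lack of consideration.",
        2: "Anticipatory bail may be sought before arrest on a non-bailable charge.",
        3: "The tenant defaulted on rent for six consecutive months.",
    }
    question = "When can anticipatory bail be granted?"
    print(retriever.retrieve(question, passages))  # e.g. {0, 2, 3}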