-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_2c_SBERT_embed.py
30 lines (24 loc) · 984 Bytes
/
_2c_SBERT_embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import csv
import os
import pickle
import time

from sentence_transformers import SentenceTransformer
# Pre-trained SBERT checkpoint used to embed every abstract in this script.
SBERT_MODEL_NAME = "all-MiniLM-L6-v2"
model = SentenceTransformer(SBERT_MODEL_NAME)
# Extract the abstract text from the cleaned data file. Despite the .txt
# extension it is parsed as comma-delimited CSV, and the abstract is taken
# from the third column (index 2) of every row.
# NOTE(review): a header line, if present, would also be treated as an
# abstract — confirm the file has none.
# newline="" is required by the csv module so quoted fields containing line
# breaks are parsed correctly; the explicit encoding removes the dependency on
# the platform locale (assumes the cleaning step wrote UTF-8 — confirm).
abstracts = []
with open(
    "cleaned_data/abstracts_over150chars.txt", newline="", encoding="utf-8"
) as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=",")
    for row in csv_reader:
        if len(row) < 3:
            # Fail with the offending location instead of a bare IndexError;
            # skipping the row would silently break the one-to-one mapping
            # between input rows and abstracts.
            raise ValueError(
                f"line {csv_reader.line_num}: expected at least 3 columns, "
                f"got {len(row)}"
            )
        abstracts.append(row[2])
# Compute one embedding per abstract. time.perf_counter() is a monotonic,
# high-resolution clock, so the reported duration is not skewed by system
# clock adjustments the way time.time() differences can be.
start_time = time.perf_counter()
embeddings = model.encode(
    abstracts,
    convert_to_numpy=True,  # allows downstream gpu or cpu similarity method
    show_progress_bar=True,
    normalize_embeddings=True,  # unit-length vectors: dot product == cosine similarity
)
# display elapsed time
end_time = time.perf_counter()
print("time taken: ", end_time - start_time)
# Save the abstract embeddings array with pickle. The output directory is
# created first so a fresh checkout does not fail with FileNotFoundError.
output_path = "models/SBERT/over150chars/embeddings_nump_nrml.pkl"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "wb") as out_file:
    pickle.dump(embeddings, out_file)