-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathembeddings.py
More file actions
64 lines (50 loc) · 1.96 KB
/
embeddings.py
File metadata and controls
64 lines (50 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from sentence_transformers import util
from structlog import get_logger
import time # Add import for timing
import config
# Width of the all-zero fallback vector returned when a source has too few
# titles to embed.
# NOTE(review): presumably this must equal the output size of the sentence
# embedding model passed in by callers -- confirm against the model in use.
EMBEDDING_DIMENSIONALITY = 384
# Module-level structlog logger shared by the functions in this file.
logger = get_logger()
def compute_source_similarity(source_1, source_2, function='cosine'):
    """Score the similarity of two source representations.

    Args:
        source_1: Representation of the first source, in a form accepted by
            the ``sentence_transformers.util`` scoring helpers.
        source_2: Representation of the second source.
        function: Name of the metric to use: ``'cosine'`` (default) or
            ``'dot'``.

    Returns:
        For ``'cosine'``, the ``[0][0]`` element of
        ``util.pytorch_cos_sim(source_1, source_2)``.
        For ``'dot'``, the unindexed result of
        ``util.dot_score(source_1, np.transpose(source_2))``.

    Raises:
        ValueError: If ``function`` names an unsupported metric. Previously
            this case fell through and silently returned ``None``.
    """
    if function == 'cosine':
        return util.pytorch_cos_sim(source_1, source_2)[0][0]
    if function == 'dot':
        # NOTE(review): unlike the cosine branch this result is not indexed
        # with [0][0], so the two metrics presumably return different shapes
        # -- confirm what callers expect before unifying them.
        return util.dot_score(source_1, np.transpose(source_2))
    raise ValueError(
        f"Unsupported similarity function {function!r}; "
        "expected 'cosine' or 'dot'"
    )
def get_source_representation_from_titles(titles, model):
    """Build a single source vector by averaging the embeddings of its titles.

    Args:
        titles: Sized collection of article titles; its length is checked
            against ``config.MINIMUM_ARTICLE_HISTORY_SIZE`` and it is passed
            unchanged to ``model.encode``.
        model: Object exposing ``encode(titles)`` whose result supports
            ``.mean(axis=0)``.

    Returns:
        ``model.encode(titles).mean(axis=0)`` when enough titles are given.
        Otherwise an all-zero array of shape ``(1, EMBEDDING_DIMENSIONALITY)``.
    """
    num_titles = len(titles)
    logger.info("get_source_representation_from_titles called", num_titles=num_titles)

    if num_titles < config.MINIMUM_ARTICLE_HISTORY_SIZE:
        logger.warn(
            "Not enough titles for source representation",
            num_titles=num_titles,
            min_required=config.MINIMUM_ARTICLE_HISTORY_SIZE,
        )
        # NOTE(review): this fallback keeps a leading axis of length 1, while
        # the normal path below reduces axis 0 away; presumably the two
        # results differ in rank -- confirm callers tolerate both.
        return np.zeros((1, EMBEDDING_DIMENSIONALITY))

    # perf_counter is monotonic, so the logged duration cannot be skewed by
    # wall-clock adjustments the way a time.time() difference can.
    encode_started = time.perf_counter()
    embeddings = model.encode(titles)
    encode_elapsed = time.perf_counter() - encode_started

    logger.info(
        "Model encoding finished",
        num_titles=num_titles,
        duration_sec=round(encode_elapsed, 3),
    )
    return embeddings.mean(axis=0)
def compute_source_representation_from_articles(articles_df, publisher_id, model):
    """Compute the source representation for one publisher's articles.

    Selects the rows of ``articles_df`` whose ``publisher_id`` column equals
    ``publisher_id``, collects their non-missing titles, and delegates to
    ``get_source_representation_from_titles``.

    Args:
        articles_df: DataFrame with at least ``publisher_id`` and ``title``
            columns.
        publisher_id: Value matched against the ``publisher_id`` column.
        model: Encoder passed through to
            ``get_source_representation_from_titles``.

    Returns:
        Whatever ``get_source_representation_from_titles`` returns for the
        selected titles.
    """
    logger.info(
        "compute_source_representation_from_articles called",
        publisher_id=publisher_id,
        dataframe_shape=articles_df.shape,
    )

    # perf_counter is monotonic, so the logged duration cannot be skewed by
    # wall-clock adjustments the way a time.time() difference can.
    filter_started = time.perf_counter()
    publisher_bucket_df = articles_df[articles_df.publisher_id == publisher_id]
    filter_elapsed = time.perf_counter() - filter_started

    logger.info(
        "DataFrame filtering finished",
        publisher_id=publisher_id,
        duration_sec=round(filter_elapsed, 3),
        filtered_shape=publisher_bucket_df.shape,
    )

    # dropna() discards None as well as NaN / pd.NA. The previous
    # `title is not None` test let pandas' own missing-value markers through
    # to the encoder.
    titles = publisher_bucket_df.title.dropna().tolist()

    # Pass the model to the helper function for encoding
    return get_source_representation_from_titles(titles, model)