Skip to content

Commit

Permalink
Update scoring function
Browse files · Browse the repository at this point in the history
Committed by raphaelsty on Aug 8, 2023
1 parent 83be549 commit 10ee52f
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 31 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ model = model.to(device)

optimizer = torch.optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=2e-6,
lr=2e-5,
)

flops_loss = losses.Flops()
Expand All @@ -69,15 +69,16 @@ for queries, documents, labels in utils.iter(
batch_size=batch_size,
shuffle=True,
):
queries_embeddings = model(queries, k=96)
queries_embeddings = model(queries, k=32)

documents_embeddings = model(documents, k=256)
documents_embeddings = model(documents, k=32)

scores = utils.scores(
queries_activations=queries_embeddings["activations"],
queries_embeddings=queries_embeddings["embeddings"],
documents_activations=documents_embeddings["activations"],
documents_embeddings=documents_embeddings["embeddings"],
device=device,
)

loss = cosine_loss.dense(
Expand Down Expand Up @@ -137,7 +138,7 @@ retriever = retrieve.Retriever(

retriever = retriever.add(
documents=documents,
k_token=64,
k_token=32, # Number of tokens to activate.
batch_size=3,
)

Expand All @@ -146,7 +147,8 @@ retriever(
"Apple",
"Banana",
],
k_sparse=64,
k_sparse=20, # Number of documents to retrieve.
k_token=32, # Number of tokens to activate.
batch_size=3
)
```
Expand Down
2 changes: 1 addition & 1 deletion sparsembed/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
VERSION = (0, 0, 2)
VERSION = (0, 0, 3)

__version__ = ".".join(map(str, VERSION))
1 change: 1 addition & 0 deletions sparsembed/losses/cosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Cosine(torch.nn.Module):
... queries_embeddings=queries_embeddings["embeddings"],
... documents_activations=documents_embeddings["activations"],
... documents_embeddings=documents_embeddings["embeddings"],
... device="cpu",
... )
>>> cosine_loss = losses.Cosine()
Expand Down
54 changes: 29 additions & 25 deletions sparsembed/retrieve/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Retriever:
... ]
>>> retriever = retriever.add(
... documents=documents,
... k_token=256,
... k_token=32,
... batch_size=24
... )
Expand All @@ -57,31 +57,31 @@ class Retriever:
... ]
>>> retriever = retriever.add(
... documents=documents,
... k_token=256,
... k_token=32,
... batch_size=24
... )
>>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=96))
[[{'id': 0, 'similarity': 1.4686675071716309},
{'id': 1, 'similarity': 1.345913052558899},
{'id': 3, 'similarity': 1.304019808769226},
{'id': 2, 'similarity': 1.1579231023788452}],
[{'id': 1, 'similarity': 7.0373148918151855},
{'id': 3, 'similarity': 3.528376817703247},
{'id': 2, 'similarity': 2.4535036087036133},
{'id': 0, 'similarity': 1.7893059253692627}],
[{'id': 2, 'similarity': 2.3167333602905273},
{'id': 3, 'similarity': 2.2312183380126953},
{'id': 1, 'similarity': 2.0195937156677246},
{'id': 0, 'similarity': 1.2890148162841797}],
[{'id': 3, 'similarity': 2.4722704887390137},
{'id': 2, 'similarity': 1.8648046255111694},
{'id': 1, 'similarity': 1.732576608657837},
{'id': 0, 'similarity': 1.3416467905044556}],
[{'id': 3, 'similarity': 3.7778899669647217},
{'id': 2, 'similarity': 3.198120355606079},
{'id': 1, 'similarity': 3.1253902912139893},
{'id': 0, 'similarity': 2.458303451538086}]]
>>> print(retriever(["Food", "Sports", "Cinema", "Music", "Hello World"], k_token=32))
[[{'id': 3, 'similarity': 0.5633876323699951},
{'id': 2, 'similarity': 0.4271728992462158},
{'id': 1, 'similarity': 0.4205787181854248},
{'id': 0, 'similarity': 0.3673652410507202}],
[{'id': 1, 'similarity': 1.547836184501648},
{'id': 3, 'similarity': 0.7415981888771057},
{'id': 2, 'similarity': 0.6557919979095459},
{'id': 0, 'similarity': 0.5385637879371643}],
[{'id': 3, 'similarity': 0.5051844716072083},
{'id': 2, 'similarity': 0.48867619037628174},
{'id': 1, 'similarity': 0.3863832950592041},
{'id': 0, 'similarity': 0.2812037169933319}],
[{'id': 3, 'similarity': 0.9398075938224792},
{'id': 1, 'similarity': 0.595514178276062},
{'id': 2, 'similarity': 0.5711489319801331},
{'id': 0, 'similarity': 0.46095147728919983}],
[{'id': 2, 'similarity': 1.3963655233383179},
{'id': 3, 'similarity': 1.2879667282104492},
{'id': 1, 'similarity': 1.229896068572998},
{'id': 0, 'similarity': 1.2129783630371094}]]
"""

Expand All @@ -104,7 +104,9 @@ def __init__(
# Documents embeddings and activations store.
self.documents_embeddings, self.documents_activations = [], []
os.environ["TOKENIZERS_PARALLELISM"] = tokenizer_parallelism
warnings.filterwarnings('ignore', '.*Sparse CSR tensor support is in beta state.*')
warnings.filterwarnings(
"ignore", ".*Sparse CSR tensor support is in beta state.*"
)

def add(
self,
Expand Down Expand Up @@ -316,8 +318,8 @@ def _intersection(t1: torch.Tensor, t2: torch.Tensor) -> list[int]:
uniques, counts = combined.unique(return_counts=True, sorted=False)
return uniques[counts > 1].tolist()

@staticmethod
def _get_scores(
self,
queries_embeddings: list[torch.Tensor],
documents_embeddings: list[list[torch.Tensor]],
intersections: list[torch.Tensor],
Expand All @@ -337,6 +339,8 @@ def _get_scores(
dim=0,
)
)
if len(intersection) > 0
else torch.tensor(0.0, device=self.model.device)
for intersection, document_embddings in zip(
query_intersections, query_documents_embddings
)
Expand Down
7 changes: 7 additions & 0 deletions sparsembed/utils/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _get_scores(
queries_embeddings_index: torch.Tensor,
documents_embeddings_index: torch.Tensor,
intersections: torch.Tensor,
device: str,
func,
) -> list:
"""Computes similarity scores between queries and documents based on activated tokens embeddings"""
Expand All @@ -52,6 +53,8 @@ def _get_scores(
[query_embeddings_index[token] for token in intersection], dim=0
)
)
if len(intersection) > 0
else torch.tensor(0.0, device=device)
for query_embeddings_index, document_embeddings_index, intersection in zip(
queries_embeddings_index, documents_embeddings_index, intersections
)
Expand All @@ -65,6 +68,7 @@ def scores(
queries_embeddings: torch.Tensor,
documents_activations: torch.Tensor,
documents_embeddings: torch.Tensor,
device: str,
func=torch.mean,
) -> list:
"""Computes score between queries and documents intersected activated tokens.
Expand Down Expand Up @@ -111,6 +115,8 @@ def scores(
... queries_embeddings=queries_embeddings["embeddings"],
... documents_activations=documents_embeddings["activations"],
... documents_embeddings=documents_embeddings["embeddings"],
... func=torch.sum, # torch.sum is dedicated to training
... device="cpu",
... )
"""
Expand All @@ -132,4 +138,5 @@ def scores(
documents_embeddings_index=documents_embeddings_index,
intersections=intersections,
func=func,
device=device,
)

0 comments on commit 10ee52f

Please sign in to comment.