Commit `ece76ba`, committed by Raphael Sourty on Aug 21, 2023 (1 parent: `10ee52f`). 25 changed files, with 1,461 additions and 603 deletions.
The largest change rewrites the README; the updated `README.md` reads as follows.

<div align="center">
    <h1>SparsEmbed - Splade</h1>
    <p>Neural search</p>
</div>
This repository presents an unofficial replication of the research papers:

- *[SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)*, authored by Thibault Formal, Benjamin Piwowarski, and Stéphane Clinchant, SIGIR 2021.

- *[SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://research.google/pubs/pub52289/)*, authored by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang, and Mike Bendersky, SIGIR 2023.
Both models learn sparse lexical representations for retrieval (SparseEmbed additionally learns contextual token-level embeddings) and can be initialized from any checkpoint compatible with the `AutoModelForMaskedLM` class from HuggingFace.

Compared with the original SparseEmbed paper, this replication does not yet implement the distillation loss (a ranking loss over training triples is used instead) and uses a single shared MLM head to compute the sparse activations of both queries and documents, rather than distinct heads.

**Note:** This project is currently a work in progress and the models are not ready to use. 🔨🧹
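For intuition about what these models compute, here is a minimal, self-contained sketch (not part of the `sparsembed` API, and only an approximation of the papers' formulation) of how a SPLADE-style sparse representation can be derived from the logits of a masked language model head: a log-saturated ReLU over the MLM logits, max-pooled over the sequence, gives one weight per vocabulary term.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
mlm = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

def sparse_activations(texts):
    """One vocabulary-sized sparse vector per text: log(1 + ReLU(logits)), max-pooled."""
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = mlm(**batch).logits  # (batch, sequence, vocabulary)
    weights = torch.log1p(torch.relu(logits))
    weights = weights * batch["attention_mask"].unsqueeze(-1)  # ignore padding tokens
    return weights.max(dim=1).values

queries = sparse_activations(["apple fruit"])
documents = sparse_activations(["Apple is a popular fruit.", "Banana is a yellow fruit."])
print(queries @ documents.T)  # dot products between vocabulary-sized sparse vectors
```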
## Installation

```
pip install sparsembed
```

If you plan to evaluate your model, install the evaluation extra as well:

```
pip install "sparsembed[eval]"
```
## Training

### Dataset

Your training dataset must be made of triples `(anchor, positive, negative)`, where the anchor is a query, the positive is a document relevant to that query, and the negative is a document that is not relevant to it.

```python
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
```
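If your data only comes as (query, relevant document) pairs, one simple way to build such triples is to sample another document as the negative. The helper below is a hypothetical sketch, not a `sparsembed` utility:

```python
import random

pairs = [
    ("what is a popular fruit", "Apple is a popular fruit."),
    ("which fruit is yellow", "Banana is a yellow fruit."),
]
corpus = [document for _, document in pairs]

X = []
for anchor, positive in pairs:
    # Sample any other document from the corpus as a (weak) negative for this anchor.
    negative = random.choice([document for document in corpus if document != positive])
    X.append((anchor, positive, negative))
```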
### Models

Both the Splade and SparsEmbed models can be initialized from any `AutoModelForMaskedLM` pretrained checkpoint:

```python
from sparsembed import model
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cpu"  # "cuda" if a GPU is available

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)
```
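Any other masked language model checkpoint can be swapped in the same way; for instance, a retrieval-oriented checkpoint such as `Luyu/co-condenser-marco`. This variant is only illustrative and repeats the wrapper call above so that it runs standalone:

```python
from sparsembed import model
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cpu"

# Same wrapper, different pretrained checkpoint.
model_marco = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("Luyu/co-condenser-marco").to(device),
    tokenizer=AutoTokenizer.from_pretrained("Luyu/co-condenser-marco"),
    device=device,
)
```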
### Splade

The following PyTorch code snippet illustrates the training loop used to fine-tune Splade, index the documents of a BEIR dataset with the `SpladeRetriever`, and evaluate the retriever. The retriever builds one sparse matrix from the sparse activations of the documents and one from the sparse activations of the queries, and matches queries to documents with the dot product of the two matrices.
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model, utils, train, retrieve
import torch

device = "cpu"  # "cuda" if a GPU is available

batch_size = 3

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Training triples: (anchor, positive, negative).
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_splade(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=1e-5,
        in_batch_negatives=True,
    )

# Index and evaluate the model on the SciFact test split of BEIR.
documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SpladeRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
)

utils.evaluate(
    retriever=retriever,
    batch_size=1,
    qrels=qrels,
    queries=queries,
    k=100,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
```
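The `flops_loss_weight` argument above weights the FLOPS regularizer used by SPLADE, which penalizes the squared mean activation of each vocabulary term so that the learned representations stay sparse. A rough standalone sketch of that regularizer (an illustration, not the actual `losses.Flops` implementation):

```python
import torch

def flops_regularizer(sparse_activations: torch.Tensor) -> torch.Tensor:
    """Sum over the vocabulary of the squared mean activation of each term,
    where the mean is taken over the batch dimension."""
    return torch.sum(torch.mean(torch.abs(sparse_activations), dim=0) ** 2)

activations = torch.relu(torch.randn(8, 30522))  # a batch of 8 vocabulary-sized vectors
penalty = flops_regularizer(activations)
```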
### SparsEmbed

The following PyTorch code snippet illustrates the training loop used to fine-tune SparsEmbed. The `SparsEmbedRetriever` performs the same sparse matching step as the `SpladeRetriever` and then re-ranks the matched documents using the similarity of their contextual token embeddings.

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model, utils, train, retrieve
import torch

device = "cpu"  # "cuda" if a GPU is available

batch_size = 3

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Training triples: (anchor, positive, negative).
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_sparsembed(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=1e-5,
        sparse_loss_weight=0.1,
        in_batch_negatives=True,
    )

# Index and evaluate the model on the SciFact test split of BEIR.
documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SparsEmbedRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
)

utils.evaluate(
    retriever=retriever,
    batch_size=1,
    qrels=qrels,
    queries=queries,
    k=100,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
```
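To make the two retrieval stages described above concrete, here is a rough, library-independent sketch: sparse dot products between vocabulary vectors select candidate documents, and (for SparsEmbed) contextual embeddings of the terms that the query and a candidate activate in common provide a re-ranking score. Every name and shape below is illustrative; this is not the `SpladeRetriever` or `SparsEmbedRetriever` code.

```python
import numpy as np
from scipy.sparse import csr_matrix

# Stage 1: sparse matching. Rows are texts, columns are vocabulary terms.
documents_activations = csr_matrix(np.array([
    [0.0, 1.2, 0.0, 0.7],
    [0.9, 0.0, 0.3, 0.0],
]))
queries_activations = csr_matrix(np.array([
    [0.0, 0.8, 0.0, 0.1],
]))

sparse_scores = (queries_activations @ documents_activations.T).toarray()
candidates = np.argsort(-sparse_scores, axis=1)  # candidate documents per query

# Stage 2 (SparsEmbed only): re-rank a candidate with contextual term embeddings,
# scoring only the vocabulary terms activated by both the query and the document.
query_term_embeddings = {"fruit": np.random.randn(64), "apple": np.random.randn(64)}
document_term_embeddings = {"fruit": np.random.randn(64), "banana": np.random.randn(64)}

rerank_score = sum(
    float(query_term_embeddings[term] @ document_term_embeddings[term])
    for term in query_term_embeddings.keys() & document_term_embeddings.keys()
)
```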
Beyond the README, the packaging and package metadata files are updated as well. In `setup.py`, an `eval` extra (pulling in `ranx` and `faiss-beir`) and the `SPLADE` keyword are added, and the `license` field is changed:

```diff
@@ -11,10 +11,12 @@
     "transformers >= 4.30.2",
 ]
 
+eval = ["ranx >= 0.3.16", "faiss-beir >= 2.0.0"]
+
 setuptools.setup(
     name="sparsembed",
     version=f"{__version__}",
-    license="MIT",
+    license="",
     author="Raphael Sourty",
     author_email="[email protected]",
     description="Sparse Embeddings for Neural Search.",
@@ -28,9 +30,13 @@
         "semantic search",
         "SparseEmbed",
         "Google Research",
+        "SPLADE",
     ],
     packages=setuptools.find_packages(),
     install_requires=base_packages,
+    extras_require={
+        "eval": base_packages + eval,
+    },
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
```
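The `eval` extra pulls in `ranx`, which the evaluation utilities presumably rely on for these IR metrics. For reference, the same metrics can be computed with `ranx` directly, independently of `sparsembed` (toy qrels and run):

```python
from ranx import Qrels, Run, evaluate

# One query with one relevant document, and a run that scores two documents.
qrels = Qrels({"query_1": {"document_1": 1}})
run = Run({"query_1": {"document_1": 0.9, "document_2": 0.3}})

scores = evaluate(qrels, run, metrics=["map", "ndcg@10", "recall@10", "hits@10"])
print(scores)
```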
The package `__init__.py` exposes the new `train` module:

```diff
-__all__ = ["data", "losses", "model", "retrieve", "utils"]
+__all__ = ["data", "losses", "model", "retrieve", "train", "utils"]
```
The version module bumps the release from 0.0.3 to 0.0.4:

```diff
-VERSION = (0, 0, 3)
+VERSION = (0, 0, 4)
 
 __version__ = ".".join(map(str, VERSION))
```
Finally, the losses package replaces the `Cosine` loss with a `Ranking` loss:

```diff
-from .cosine import Cosine
 from .flops import Flops
+from .ranking import Ranking
 
-__all__ = ["Flops", "Cosine"]
+__all__ = ["Flops", "Ranking"]
```
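The new `Ranking` loss itself is not shown in this commit. As a rough sketch of what a ranking objective with in-batch negatives can look like (an assumption about the general technique, not the actual `losses.Ranking` implementation):

```python
import torch
import torch.nn.functional as F

def in_batch_ranking_loss(anchors, positives, negatives):
    """Softmax cross-entropy where each anchor must score its own positive above
    every other positive and negative in the batch (in-batch negatives)."""
    candidates = torch.cat([positives, negatives], dim=0)   # (2 * batch, vocabulary)
    scores = anchors @ candidates.T                         # (batch, 2 * batch)
    labels = torch.arange(anchors.shape[0])                 # own positive sits at index i
    return F.cross_entropy(scores, labels)

anchors = torch.relu(torch.randn(3, 30522))    # sparse activations, batch of 3
positives = torch.relu(torch.randn(3, 30522))
negatives = torch.relu(torch.randn(3, 30522))
loss = in_batch_ranking_loss(anchors, positives, negatives)
```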