
Commit

Add splade
Raphael Sourty committed Aug 21, 2023
1 parent 10ee52f commit ece76ba
Showing 25 changed files with 1,461 additions and 603 deletions.
21 changes: 0 additions & 21 deletions LICENSE

This file was deleted.

262 changes: 140 additions & 122 deletions README.md
@@ -1,167 +1,185 @@
<div align="center">
  <h1>SparsEmbed - Splade</h1>
  <p>Neural search</p>
</div>

This repository presents an unofficial replication of the research papers:

- *[SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking](https://arxiv.org/abs/2107.05720)* authored by Thibault Formal, Benjamin Piwowarski, and Stéphane Clinchant, SIGIR 2021.

- *[SparseEmbed: Learning Sparse Lexical Representations with Contextual Embeddings for Retrieval](https://research.google/pubs/pub52289/)* authored by Weize Kong, Jeffrey M. Dudek, Cheng Li, Mingyang Zhang, and Mike Bendersky, SIGIR 2023.

**Note:** This project is currently a work in progress and the models are not yet ready to use. 🔨🧹

## Installation

```
pip install sparsembed
```

If you plan to evaluate your model, install:

```
pip install "sparsembed[eval]"
```

## Training

### Dataset

Your training dataset must be made of triples `(anchor, positive, negative)`, where the anchor is a query, the positive is a document relevant to that query, and the negative is a document that is not relevant to it.

```python
X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
```

### Models

Both the Splade and SparseEmbed models can be initialized from any pretrained `AutoModelForMaskedLM` checkpoint.

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

from sparsembed import model

device = "cpu"  # or "cuda"

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)
```
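For intuition about what these models compute, here is a minimal sketch that follows the SPLADE paper's recipe rather than this repository's code, so the max pooling and other details are assumptions: the MLM head produces one logit per vocabulary term for every input token, and a sparse query or document vector is obtained by applying a log-saturated ReLU and pooling over the token axis. SparseEmbed additionally keeps a contextual embedding for each activated term, which is used to re-rank the candidates returned by the sparse match.

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Minimal sketch of SPLADE-style sparse activations, following the paper.
# The pooling choice and other details are assumptions, not this repository's code.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
mlm = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["Apple is a popular fruit."], return_tensors="pt")
logits = mlm(**batch).logits  # (batch, tokens, vocabulary)

# Log-saturated ReLU, then pooling over tokens: one non-negative weight per vocabulary term.
activations = torch.log1p(torch.relu(logits)).max(dim=1).values  # (batch, vocabulary)

# Most weights are (near) zero; the non-zero terms form an expanded, weighted
# bag of words that can be stored in an inverted index.
weights, token_ids = activations.topk(k=10, dim=1)
print(tokenizer.convert_ids_to_tokens(token_ids[0].tolist()))
```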

### Splade

The following PyTorch code snippet illustrates the training loop used to fine-tune Splade:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model, utils, train, retrieve
import torch

device = "cpu"  # or "cuda"

batch_size = 3

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_splade(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=1e-5,
        in_batch_negatives=True,
    )

documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SpladeRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
)

utils.evaluate(
    retriever=retriever,
    batch_size=1,
    qrels=qrels,
    queries=queries,
    k=100,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
```
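Once documents have been added, the retriever can be queried directly. The call below continues the snippet above and mirrors the `Retriever` example from the previous version of this README, so the exact `SpladeRetriever` query signature (`q`, `k_sparse`) is an assumption; each query should come back as a ranked list of `{"id": ..., "similarity": ...}` dictionaries.

```python
# Continues the snippet above. The query signature mirrors the earlier
# Retriever example and is an assumption for SpladeRetriever.
candidates = retriever(
    q=["Apple", "Banana"],
    k_sparse=20,  # Number of documents to retrieve per query.
    batch_size=batch_size,
)
```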

### SparsEmbed

The following PyTorch code snippet illustrates the training loop used to fine-tune SparseEmbed:

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
from sparsembed import model, utils, train, retrieve
import torch

device = "cpu"  # or "cuda"

batch_size = 3

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
    X,
    epochs=1,
    batch_size=batch_size,
    shuffle=True,
):
    loss = train.train_sparsembed(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        flops_loss_weight=1e-5,
        sparse_loss_weight=0.1,
        in_batch_negatives=True,
    )

documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SparsEmbedRetriever(
    key="id",
    on=["title", "text"],
    model=model,
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
)

utils.evaluate(
    retriever=retriever,
    batch_size=1,
    qrels=qrels,
    queries=queries,
    k=100,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"],
)
```
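For intuition about the extra `sparse_loss_weight` argument: the previous version of this README combined a ranking loss on dense (contextual-embedding) scores with a down-weighted ranking loss on sparse-activation scores and a FLOPS regularizer. The sketch below reproduces that structure with plain tensors; the hinge formulation is an assumption, and this is not the library's `train_sparsembed` implementation.

```python
import torch

# Conceptual sketch of the combined SparsEmbed objective, mirroring the weights
# used above. The hinge-style ranking terms are an assumption; this is not the
# library's train_sparsembed implementation.
def sparsembed_objective(
    dense_pos: torch.Tensor,    # (batch,) dense scores of (anchor, positive) pairs
    dense_neg: torch.Tensor,    # (batch,) dense scores of (anchor, negative) pairs
    sparse_pos: torch.Tensor,   # (batch,) sparse scores of (anchor, positive) pairs
    sparse_neg: torch.Tensor,   # (batch,) sparse scores of (anchor, negative) pairs
    activations: torch.Tensor,  # (batch, vocabulary) sparse activations
    sparse_loss_weight: float = 0.1,
    flops_loss_weight: float = 1e-5,
) -> torch.Tensor:
    # Ranking terms: the positive document should outscore the negative one.
    ranking_dense = torch.relu(1.0 - (dense_pos - dense_neg)).mean()
    ranking_sparse = torch.relu(1.0 - (sparse_pos - sparse_neg)).mean()
    # FLOPS regularizer (SPLADE): squared mean activation of each vocabulary term.
    flops = (activations.mean(dim=0) ** 2).sum()
    return ranking_dense + sparse_loss_weight * ranking_sparse + flops_loss_weight * flops
```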
8 changes: 7 additions & 1 deletion setup.py
@@ -11,10 +11,12 @@
"transformers >= 4.30.2",
]

eval = ["ranx >= 0.3.16", "faiss-beir >= 2.0.0"]

setuptools.setup(
name="sparsembed",
version=f"{__version__}",
license="MIT",
license="",
author="Raphael Sourty",
author_email="[email protected]",
description="Sparse Embeddings for Neural Search.",
@@ -28,9 +30,13 @@
"semantic search",
"SparseEmbed",
"Google Research",
"SPLADE",
],
packages=setuptools.find_packages(),
install_requires=base_packages,
extras_require={
"eval": base_packages + eval,
},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
2 changes: 1 addition & 1 deletion sparsembed/__init__.py
@@ -1 +1 @@
__all__ = ["data", "losses", "model", "retrieve", "utils"]
__all__ = ["data", "losses", "model", "retrieve", "train", "utils"]
2 changes: 1 addition & 1 deletion sparsembed/__version__.py
@@ -1,3 +1,3 @@
-VERSION = (0, 0, 3)
+VERSION = (0, 0, 4)
 
 __version__ = ".".join(map(str, VERSION))
4 changes: 2 additions & 2 deletions sparsembed/losses/__init__.py
@@ -1,4 +1,4 @@
-from .cosine import Cosine
 from .flops import Flops
+from .ranking import Ranking
 
-__all__ = ["Flops", "Cosine"]
+__all__ = ["Flops", "Ranking"]